diff --git a/.github/workflows/linux-build.yml b/.github/workflows/linux-build.yml
deleted file mode 100644
index 9f9a81d27d9..00000000000
--- a/.github/workflows/linux-build.yml
+++ /dev/null
@@ -1,72 +0,0 @@
-name: Build Linux SpacetimeDB CLI
-
-on:
-  push:
-    tags:
-      - '**'
-    branches:
-      - master
-      - release/*
-
-jobs:
-  linux-amd64-cli:
-    runs-on: bare-metal
-
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-
-      - name: Compile x86
-        run: |
-          export PATH=$HOME/.cargo/bin:$PATH
-          cargo build --release -p spacetimedb-cli
-          mkdir build
-          cp --sparse=never target/release/spacetime build/spacetime
-          cd build && tar -czf spacetime.linux-amd64.tar.gz spacetime
-          rm spacetime
-
-      - name: Extract branch name
-        shell: bash
-        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
-        id: extract_branch
-
-      - name: Upload to DO Spaces
-        uses: shallwefootball/s3-upload-action@master
-        with:
-          aws_key_id: ${{ secrets.AWS_KEY_ID }}
-          aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY}}
-          aws_bucket: ${{ vars.AWS_BUCKET }}
-          source_dir: build
-          endpoint: https://nyc3.digitaloceanspaces.com
-          destination_dir: ${{ steps.extract_branch.outputs.branch }}
-
-  linux-arm64-cli:
-    runs-on: arm-runner
-
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-
-      - name: Compile ARM64
-        run: |
-          export PATH=$HOME/.cargo/bin:$PATH
-          cargo build --release -p spacetimedb-cli
-          mkdir build
-          cp --sparse=never target/release/spacetime build/spacetime
-          cd build && tar -czf spacetime.linux-arm64.tar.gz spacetime
-          rm spacetime
-
-      - name: Extract branch name
-        shell: bash
-        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
-        id: extract_branch
-
-      - name: Upload to DO Spaces
-        uses: shallwefootball/s3-upload-action@master
-        with:
-          aws_key_id: ${{ secrets.AWS_KEY_ID }}
-          aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY}}
-          aws_bucket: ${{ vars.AWS_BUCKET }}
-          source_dir: build
-          endpoint: https://nyc3.digitaloceanspaces.com
-          destination_dir: ${{ steps.extract_branch.outputs.branch }}
diff --git a/.github/workflows/macos-build.yml b/.github/workflows/macos-build.yml
deleted file mode 100644
index 427b256bd4e..00000000000
--- a/.github/workflows/macos-build.yml
+++ /dev/null
@@ -1,57 +0,0 @@
-name: Build MacOS SpacetimeDB CLI
-
-on:
-  push:
-    tags:
-      - '**'
-    branches:
-      - master
-      - release/*
-
-jobs:
-  macos-cli:
-    runs-on: macos-latest
-
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-
-      - name: Add cross-platform targets
-        run: |
-          rustup target add aarch64-apple-darwin
-          rustup target add x86_64-apple-darwin
-
-      - name: Compile x86
-        run: |
-          cargo build --release -p spacetimedb-cli --target x86_64-apple-darwin
-          mkdir build
-          file target/x86_64-apple-darwin/release/spacetime
-          dd if=target/x86_64-apple-darwin/release/spacetime of=build/spacetime conv=noerror,sync
-          chmod +x build/spacetime
-          cd build && tar -czf spacetime.darwin-amd64.tar.gz spacetime
-          rm spacetime
-
-      - name: Compile Aarch64
-        run: |
-          cargo build --release -p spacetimedb-cli --target=aarch64-apple-darwin
-          file target/aarch64-apple-darwin/release/spacetime
-          # dd is used to avoid incompatibilities between BSD tar vs. GNU tar
-          dd if=target/aarch64-apple-darwin/release/spacetime of=build/spacetime conv=noerror,sync
-          chmod +x build/spacetime
-          cd build && tar -czf spacetime.darwin-arm64.tar.gz spacetime
-          rm spacetime
-
-      - name: Extract branch name
-        shell: bash
-        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
-        id: extract_branch
-
-      - name: Upload to DO Spaces
-        uses: shallwefootball/s3-upload-action@master
-        with:
-          aws_key_id: ${{ secrets.AWS_KEY_ID }}
-          aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY}}
-          aws_bucket: ${{ vars.AWS_BUCKET }}
-          source_dir: build
-          endpoint: https://nyc3.digitaloceanspaces.com
-          destination_dir: ${{ steps.extract_branch.outputs.branch }}
diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml
new file mode 100644
index 00000000000..640979c8f93
--- /dev/null
+++ b/.github/workflows/package.yml
@@ -0,0 +1,70 @@
+name: Package SpacetimeDB CLI
+
+on:
+  push:
+    tags:
+      - '**'
+    branches:
+      - master
+      - release/*
+
+jobs:
+  build-cli:
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - { name: x86_64 Linux, target: x86_64-unknown-linux-gnu, runner: bare-metal }
+          - { name: aarch64 Linux, target: aarch64-unknown-linux-gnu, runner: arm-runner }
+          - { name: aarch64 macOS, target: aarch64-apple-darwin, runner: macos-latest }
+          - { name: x86_64 macOS, target: x86_64-apple-darwin, runner: macos-latest }
+          - { name: x86_64 Windows, target: x86_64-pc-windows-msvc, runner: windows-latest }
+
+    name: Build CLI for ${{ matrix.name }}
+    runs-on: ${{ matrix.runner }}
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+
+      - name: Show arch
+        run: uname -a
+
+      - name: Install Rust
+        uses: dsherret/rust-toolchain-file@v1
+
+      - name: Install rust target
+        run: rustup target add ${{ matrix.target }}
+
+      - name: Compile
+        run: |
+          cargo build --release --target ${{ matrix.target }} -p spacetimedb-cli -p spacetimedb-standalone -p spacetimedb-update
+
+      - name: Package (unix)
+        if: ${{ runner.os != 'Windows' }}
+        run: |
+          mkdir build
+          cd target/${{matrix.target}}/release
+          tar -czf ../../../build/spacetime-${{matrix.target}}.tar.gz spacetimedb-{cli,standalone,update}
+
+      - name: Package (windows)
+        if: ${{ runner.os == 'Windows' }}
+        run: |
+          mkdir build
+          cd target/${{matrix.target}}/release
+          7z a ../../../build/spacetime-${{matrix.target}}.zip spacetimedb-cli.exe spacetimedb-standalone.exe spacetimedb-update.exe
+
+      - name: Extract branch name
+        shell: bash
+        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
+        id: extract_branch
+
+      - name: Upload to DO Spaces
+        uses: shallwefootball/s3-upload-action@master
+        with:
+          aws_key_id: ${{ secrets.AWS_KEY_ID }}
+          aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY}}
+          aws_bucket: ${{ vars.AWS_BUCKET }}
+          source_dir: build
+          endpoint: https://nyc3.digitaloceanspaces.com
+          destination_dir: ${{ steps.extract_branch.outputs.branch }}
diff --git a/.github/workflows/windows-build.yml b/.github/workflows/windows-build.yml
deleted file mode 100644
index 8fbd0d6bb50..00000000000
--- a/.github/workflows/windows-build.yml
+++ /dev/null
@@ -1,41 +0,0 @@
-name: Build Windows SpacetimeDB CLI
-
-on:
-  push:
-    tags:
-      - '**'
-    branches:
-      - master
-      - release/*
-
-jobs:
-  windows-cli:
-    runs-on: windows-latest
-
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-
-      - name: Show arch
-        run: uname -a
-
-      - name: Compile
-        run: |
-          cargo build --release -p spacetimedb-cli
-          mkdir build
-          mv target/release/spacetime.exe build/spacetime.exe
-
-      - name: Extract branch name
-        shell: bash
-        run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
-        id: extract_branch
-
-      - name: Upload to DO Spaces
-        uses: shallwefootball/s3-upload-action@master
-        with:
-          aws_key_id: ${{ secrets.AWS_KEY_ID }}
-          aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY}}
-          aws_bucket: ${{ vars.AWS_BUCKET }}
-          source_dir: build
-          endpoint: https://nyc3.digitaloceanspaces.com
-          destination_dir: ${{ steps.extract_branch.outputs.branch }}
diff --git a/Cargo.lock b/Cargo.lock
index 032b0882fd5..d07876ad5dd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3599,9 +3599,9 @@ dependencies = [
 
 [[package]]
 name = "proptest-derive"
-version = "0.5.0"
+version = "0.5.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6ff7ff745a347b87471d859a377a9a404361e7efc2a971d73424a6d183c0fc77"
+checksum = "4ee1c9ac207483d5e7db4940700de86a9aae46ef90c48b57f99fe7edb8345e49"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -5004,6 +5004,7 @@ dependencies = [
  "spacetimedb-lib",
  "spacetimedb-physical-plan",
  "spacetimedb-primitives",
+ "spacetimedb-sats",
  "spacetimedb-sql-parser",
  "spacetimedb-table",
 ]
diff --git a/crates/cli/src/common_args.rs b/crates/cli/src/common_args.rs
index 9ffb0ab40db..a238ecbf0d8 100644
--- a/crates/cli/src/common_args.rs
+++ b/crates/cli/src/common_args.rs
@@ -27,5 +27,5 @@ pub fn yes() -> Arg {
         .long("yes")
         .short('y')
         .action(SetTrue)
-        .help("Assume \"yes\" as answer to all prompts and run non-interactively")
+        .help("Run non-interactively wherever possible. This will answer \"yes\" to almost all prompts, but will sometimes answer \"no\" to preserve non-interactivity (e.g. when prompting whether to log in with spacetimedb.com).")
}
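Every subcommand touched below wires this flag up the same way: register it with `common_args::yes()`, read it back under the id `force`, and thread `!force` through as the `interactive` argument. A minimal sketch of that pattern, for reference (the `example` command name is illustrative, not from this diff):

```rust
use clap::ArgMatches;

use crate::common_args;
use crate::config::Config;
use crate::util::get_auth_header;

// Hypothetical subcommand wired the same way as `call`, `delete`, etc. below.
pub fn cli() -> clap::Command {
    // `common_args::yes()` registers `-y`/`--yes` under the arg id "force".
    clap::Command::new("example").arg(common_args::yes())
}

pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
    // `--yes` suppresses prompts; its inverse is passed around as `interactive`.
    let force = args.get_flag("force");
    let _auth_header = get_auth_header(&mut config, false, None, !force).await?;
    Ok(())
}
```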
diff --git a/crates/cli/src/config.rs b/crates/cli/src/config.rs
index 069cebb7ea6..f0cde779513 100644
--- a/crates/cli/src/config.rs
+++ b/crates/cli/src/config.rs
@@ -820,14 +820,6 @@ Update the server's fingerprint with:
     pub fn spacetimedb_token(&self) -> Option<&String> {
         self.home.spacetimedb_token.as_ref()
     }
-
-    pub fn spacetimedb_token_or_error(&self) -> anyhow::Result<&String> {
-        if let Some(token) = self.spacetimedb_token() {
-            Ok(token)
-        } else {
-            Err(anyhow::anyhow!("No login token found. Please run `spacetime login`."))
-        }
-    }
 }
 
 #[cfg(test)]
diff --git a/crates/cli/src/subcommands/call.rs b/crates/cli/src/subcommands/call.rs
index f9501f3d993..335b514c7c4 100644
--- a/crates/cli/src/subcommands/call.rs
+++ b/crates/cli/src/subcommands/call.rs
@@ -30,6 +30,7 @@ pub fn cli() -> clap::Command {
         .arg(Arg::new("arguments").help("arguments formatted as JSON").num_args(1..))
         .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
         .arg(common_args::anonymous())
+        .arg(common_args::yes())
         .after_help("Run `spacetime help call` for more detailed information.\n")
 }
 
@@ -38,6 +39,7 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), Error> {
     let reducer_name = args.get_one::<String>("reducer_name").unwrap();
     let arguments = args.get_many::<String>("arguments");
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
+    let force = args.get_flag("force");
 
     let anon_identity = args.get_flag("anon_identity");
 
@@ -49,7 +51,7 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), Error> {
         database_identity.clone(),
         reducer_name
     ));
-    let auth_header = get_auth_header(&config, anon_identity)?;
+    let auth_header = get_auth_header(&mut config, anon_identity, server, !force).await?;
     let builder = add_auth_header_opt(builder, &auth_header);
     let describe_reducer = util::describe_reducer(
         &mut config,
@@ -57,6 +59,7 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), Error> {
         server.map(|x| x.to_string()),
         reducer_name.clone(),
         anon_identity,
+        !force,
     )
     .await?;
diff --git a/crates/cli/src/subcommands/delete.rs b/crates/cli/src/subcommands/delete.rs
index b73f8ba47a6..d2a4aea13d7 100644
--- a/crates/cli/src/subcommands/delete.rs
+++ b/crates/cli/src/subcommands/delete.rs
@@ -12,17 +12,19 @@ pub fn cli() -> clap::Command {
                 .help("The name or identity of the database to delete"),
         )
         .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
+        .arg(common_args::yes())
         .after_help("Run `spacetime help delete` for more detailed information.\n")
 }
 
-pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
     let database = args.get_one::<String>("database").unwrap();
+    let force = args.get_flag("force");
 
     let identity = database_identity(&config, database, server).await?;
 
     let builder = reqwest::Client::new().post(format!("{}/database/delete/{}", config.get_host_url(server)?, identity));
-    let auth_header = get_auth_header(&config, false)?;
+    let auth_header = get_auth_header(&mut config, false, server, !force).await?;
     let builder = add_auth_header_opt(builder, &auth_header);
 
     builder.send().await?.error_for_status()?;
diff --git a/crates/cli/src/subcommands/describe.rs b/crates/cli/src/subcommands/describe.rs
index 95bc52ecafd..0cf6a7f2a38 100644
--- a/crates/cli/src/subcommands/describe.rs
+++ b/crates/cli/src/subcommands/describe.rs
@@ -23,14 +23,16 @@ pub fn cli() -> clap::Command {
         )
         .arg(common_args::anonymous())
         .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
+        .arg(common_args::yes())
         .after_help("Run `spacetime help describe` for more detailed information.\n")
 }
 
-pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let database = args.get_one::<String>("database").unwrap();
     let entity_name = args.get_one::<String>("entity_name");
     let entity_type = args.get_one::<String>("entity_type");
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
+    let force = args.get_flag("force");
 
     let anon_identity = args.get_flag("anon_identity");
 
@@ -46,7 +48,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
             entity_name
         ),
     });
-    let auth_header = get_auth_header(&config, anon_identity)?;
+    let auth_header = get_auth_header(&mut config, anon_identity, server, !force).await?;
     let builder = add_auth_header_opt(builder, &auth_header);
 
     let descr = builder.send().await?.error_for_status()?.text().await?;
diff --git a/crates/cli/src/subcommands/dns.rs b/crates/cli/src/subcommands/dns.rs
index e14a363c6b7..b270aa271b2 100644
--- a/crates/cli/src/subcommands/dns.rs
+++ b/crates/cli/src/subcommands/dns.rs
@@ -1,6 +1,8 @@
 use crate::common_args;
 use crate::config::Config;
-use crate::util::{add_auth_header_opt, decode_identity, get_auth_header, spacetime_register_tld};
+use crate::util::{
+    add_auth_header_opt, decode_identity, get_auth_header, get_login_token_or_log_in, spacetime_register_tld,
+};
 use clap::ArgMatches;
 use clap::{Arg, Command};
 use reqwest::Url;
@@ -22,17 +24,20 @@ pub fn cli() -> Command {
                 .help("The database identity to rename"),
         )
         .arg(common_args::server().help("The nickname, host name or URL of the server on which to set the name"))
+        .arg(common_args::yes())
         .after_help("Run `spacetime rename --help` for more detailed information.\n")
 }
 
-pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let domain = args.get_one::<String>("new-name").unwrap();
     let database_identity = args.get_one::<String>("database-identity").unwrap();
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
-    let identity = decode_identity(&config)?;
-    let auth_header = get_auth_header(&config, false)?;
+    let force = args.get_flag("force");
+    let token = get_login_token_or_log_in(&mut config, server, !force).await?;
+    let identity = decode_identity(&token)?;
+    let auth_header = get_auth_header(&mut config, false, server, !force).await?;
 
-    match spacetime_register_tld(&config, domain, server).await? {
+    match spacetime_register_tld(&mut config, domain, server, !force).await? {
         RegisterTldResult::Success { domain } => {
             println!("Registered domain: {}", domain);
         }
diff --git a/crates/cli/src/subcommands/energy.rs b/crates/cli/src/subcommands/energy.rs
index 117041ecf79..555b4e7a5a5 100644
--- a/crates/cli/src/subcommands/energy.rs
+++ b/crates/cli/src/subcommands/energy.rs
@@ -3,7 +3,7 @@ use crate::common_args;
 use clap::ArgMatches;
 
 use crate::config::Config;
-use crate::util;
+use crate::util::{self, get_login_token_or_log_in};
 
 pub fn cli() -> clap::Command {
     clap::Command::new("energy")
@@ -26,7 +26,8 @@ fn get_energy_subcommands() -> Vec<Command> {
         .arg(
             common_args::server()
                 .help("The nickname, host name or URL of the server from which to request balance information"),
-        )]
+        )
+        .arg(common_args::yes())]
 }
 
 async fn exec_subcommand(config: Config, cmd: &str, args: &ArgMatches) -> Result<(), anyhow::Error> {
@@ -41,15 +42,17 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
     exec_subcommand(config, cmd, subcommand_args).await
 }
 
-async fn exec_status(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+async fn exec_status(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     // let project_name = args.value_of("project name").unwrap();
     let identity = args.get_one::<String>("identity");
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
+    let force = args.get_flag("force");
 
     // TODO: We should remove the ability to call this for arbitrary users. At *least* remove it from the CLI.
     let identity = if let Some(identity) = identity {
         identity.clone()
     } else {
-        util::decode_identity(&config)?
+        let token = get_login_token_or_log_in(&mut config, server, !force).await?;
+        util::decode_identity(&token)?
     };
 
     let status = reqwest::Client::new()
diff --git a/crates/cli/src/subcommands/generate/rust.rs b/crates/cli/src/subcommands/generate/rust.rs
index 259620f72f9..fb8533c25b7 100644
--- a/crates/cli/src/subcommands/generate/rust.rs
+++ b/crates/cli/src/subcommands/generate/rust.rs
@@ -1426,6 +1426,7 @@ impl __sdk::EventContext for EventContext {{
 
 /// A handle on a subscribed query.
 // TODO: Document this better after implementing the new subscription API.
+#[derive(Clone)]
 pub struct SubscriptionHandle {{
     imp: __sdk::SubscriptionHandleImpl,
 }}
 
@@ -1438,6 +1439,27 @@ impl __sdk::SubscriptionHandle for SubscriptionHandle {{
     fn new(imp: __sdk::SubscriptionHandleImpl) -> Self {{
         Self {{ imp }}
     }}
+
+    /// Returns true if this subscription has been terminated due to an unsubscribe call or an error.
+    fn is_ended(&self) -> bool {{
+        self.imp.is_ended()
+    }}
+
+    /// Returns true if this subscription has been applied and has not yet been unsubscribed.
+    fn is_active(&self) -> bool {{
+        self.imp.is_active()
+    }}
+
+    /// Unsubscribe from the query controlled by this `SubscriptionHandle`,
+    /// then run `on_end` when its rows are removed from the client cache.
+    fn unsubscribe_then(self, on_end: __sdk::OnEndedCallback) -> __anyhow::Result<()> {{
+        self.imp.unsubscribe_then(Some(on_end))
+    }}
+
+    fn unsubscribe(self) -> __anyhow::Result<()> {{
+        self.imp.unsubscribe_then(None)
+    }}
+
 }}
 
 /// Alias trait for a [`__sdk::DbContext`] connected to this module,
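The generated methods above all delegate to the SDK's `SubscriptionHandleImpl`, so code that only needs the lifecycle queries can stay generic over the `__sdk::SubscriptionHandle` trait rather than naming the generated type. A hedged sketch (the `spacetimedb_sdk` crate path is an assumption; generated code refers to it as `__sdk`):

```rust
use spacetimedb_sdk::SubscriptionHandle;

/// Tear a subscription down, but only if it is still live.
/// Works with any generated `SubscriptionHandle`, which now also derives `Clone`.
fn end_if_active<H: SubscriptionHandle>(handle: H) -> anyhow::Result<()> {
    if handle.is_active() {
        // `unsubscribe` is the no-callback variant of `unsubscribe_then`.
        handle.unsubscribe()?;
    }
    Ok(())
}
```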
diff --git a/crates/cli/src/subcommands/list.rs b/crates/cli/src/subcommands/list.rs
index b1d6ad7526f..861e05de334 100644
--- a/crates/cli/src/subcommands/list.rs
+++ b/crates/cli/src/subcommands/list.rs
@@ -1,5 +1,6 @@
 use crate::common_args;
 use crate::util;
+use crate::util::get_login_token_or_log_in;
 use crate::Config;
 use clap::{ArgMatches, Command};
 use reqwest::StatusCode;
@@ -14,6 +15,7 @@ pub fn cli() -> Command {
     Command::new("list")
         .about("Lists the databases attached to an identity")
         .arg(common_args::server().help("The nickname, host name or URL of the server from which to list databases"))
+        .arg(common_args::yes())
 }
 
 #[derive(Deserialize)]
@@ -27,9 +29,11 @@ struct IdentityRow {
     pub db_identity: Identity,
 }
 
-pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
-    let identity = util::decode_identity(&config)?;
+    let force = args.get_flag("force");
+    let token = get_login_token_or_log_in(&mut config, server, !force).await?;
+    let identity = util::decode_identity(&token)?;
 
     let client = reqwest::Client::new();
     let res = client
@@ -38,7 +42,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
             config.get_host_url(server)?,
             identity
         ))
-        .basic_auth("token", Some(config.spacetimedb_token_or_error()?))
+        .basic_auth("token", Some(token))
         .send()
         .await?;
diff --git a/crates/cli/src/subcommands/login.rs b/crates/cli/src/subcommands/login.rs
index 4276e8b7117..d4d49241483 100644
--- a/crates/cli/src/subcommands/login.rs
+++ b/crates/cli/src/subcommands/login.rs
@@ -5,6 +5,8 @@ use reqwest::Url;
 use serde::Deserialize;
 use webbrowser;
 
+pub const DEFAULT_AUTH_HOST: &str = "https://spacetimedb.com";
+
 pub fn cli() -> Command {
     Command::new("login")
         .args_conflicts_with_subcommands(true)
@@ -13,7 +15,7 @@ pub fn cli() -> Command {
         .arg(
             Arg::new("auth-host")
                 .long("auth-host")
-                .default_value("https://spacetimedb.com")
+                .default_value(DEFAULT_AUTH_HOST)
                 .group("login-method")
                 .help("Fetch login token from a different host"),
         )
@@ -79,13 +81,17 @@ async fn exec_subcommand(config: Config, cmd: &str, args: &ArgMatches) -> Result
 
 async fn exec_show(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let include_token = args.get_flag("token");
 
-    let identity = decode_identity(&config)?;
+    let token = if let Some(token) = config.spacetimedb_token() {
+        token
+    } else {
+        println!("You are not logged in. Run `spacetime login` to log in.");
+        return Ok(());
+    };
+
+    let identity = decode_identity(token)?;
     println!("You are logged in as {}", identity);
 
     if include_token {
-        // We can `unwrap` because `decode_identity` fetches this too.
-        // TODO: maybe decode_identity should take token as a param.
-        let token = config.spacetimedb_token().unwrap();
         println!("Your auth token (don't share this!) is {}", token);
    }
 
@@ -100,18 +106,26 @@ async fn spacetimedb_token_cached(config: &mut Config, host: &Url, direct_login:
         println!("If you want to log out, use spacetime logout.");
         Ok(token.clone())
     } else {
-        let token = if direct_login {
-            spacetimedb_direct_login(host).await?
-        } else {
-            let session_token = web_login_cached(config, host).await?;
-            spacetimedb_login(host, &session_token).await?
-        };
-        config.set_spacetimedb_token(token.clone());
-        config.save();
-        Ok(token)
+        spacetimedb_login_force(config, host, direct_login).await
     }
 }
 
+pub async fn spacetimedb_login_force(config: &mut Config, host: &Url, direct_login: bool) -> anyhow::Result<String> {
+    let token = if direct_login {
+        let token = spacetimedb_direct_login(host).await?;
+        println!("We have logged in directly to your target server.");
+        println!("WARNING: This login will NOT work for any other servers.");
+        token
+    } else {
+        let session_token = web_login_cached(config, host).await?;
+        spacetimedb_login(host, &session_token).await?
+    };
+    config.set_spacetimedb_token(token.clone());
+    config.save();
+
+    Ok(token)
+}
+
 async fn web_login_cached(config: &mut Config, host: &Url) -> anyhow::Result<String> {
     if let Some(session_token) = config.web_session_token() {
         // Currently, these session tokens do not expire. At some point in the future, we may also need to check this session token for validity.
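For reference, the two paths `spacetimedb_login_force` can take, written out as a caller would choose between them. This mirrors `get_login_token_or_log_in` in `util.rs` further down; the `use_web_login` parameter and the `None` target server are illustrative:

```rust
use reqwest::Url;

use crate::config::Config;
use crate::login::{spacetimedb_login_force, DEFAULT_AUTH_HOST};

async fn log_in(config: &mut Config, use_web_login: bool) -> anyhow::Result<String> {
    if use_web_login {
        // Browser-based login against spacetimedb.com; the token works across servers.
        let host = Url::parse(DEFAULT_AUTH_HOST)?;
        spacetimedb_login_force(config, &host, false).await
    } else {
        // Direct login mints a token from the target server only (see the WARNING above).
        let host = Url::parse(&config.get_host_url(None)?)?;
        spacetimedb_login_force(config, &host, true).await
    }
}
```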
diff --git a/crates/cli/src/subcommands/logs.rs b/crates/cli/src/subcommands/logs.rs
index 887af4a39cf..8a3c7f3d4d4 100644
--- a/crates/cli/src/subcommands/logs.rs
+++ b/crates/cli/src/subcommands/logs.rs
@@ -47,6 +47,7 @@ pub fn cli() -> clap::Command {
                 .value_parser(clap::value_parser!(Format))
                 .help("Output format for the logs")
         )
+        .arg(common_args::yes())
         .after_help("Run `spacetime help logs` for more detailed information.\n")
 }
 
@@ -109,14 +110,15 @@ impl clap::ValueEnum for Format {
     }
 }
 
-pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
+    let force = args.get_flag("force");
     let mut num_lines = args.get_one::<u32>("num_lines").copied();
     let database = args.get_one::<String>("database").unwrap();
     let follow = args.get_flag("follow");
     let format = *args.get_one::<Format>("format").unwrap();
 
-    let auth_header = get_auth_header(&config, false)?;
+    let auth_header = get_auth_header(&mut config, false, server, !force).await?;
 
     let database_identity = database_identity(&config, database, server).await?;
diff --git a/crates/cli/src/subcommands/publish.rs b/crates/cli/src/subcommands/publish.rs
index 64183173d89..20b904afd39 100644
--- a/crates/cli/src/subcommands/publish.rs
+++ b/crates/cli/src/subcommands/publish.rs
@@ -65,7 +65,7 @@ pub fn cli() -> clap::Command {
         .after_help("Run `spacetime help publish` for more detailed information.")
 }
 
-pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
+pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let server = args.get_one::<String>("server").map(|s| s.as_str());
     let name_or_identity = args.get_one::<String>("name|identity");
     let path_to_project = args.get_one::<PathBuf>("project_path").unwrap();
@@ -80,7 +80,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
     // we want to use the default identity
     // TODO(jdetter): We should maybe have some sort of user prompt here for them to be able to
     // easily create a new identity with an email
-    let auth_header = get_auth_header(&config, anon_identity)?;
+    let auth_header = get_auth_header(&mut config, anon_identity, server, !force).await?;
 
     let mut query_params = Vec::<(&str, &str)>::new();
     query_params.push(("host_type", "wasm"));
@@ -159,7 +159,9 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
     let res = builder.body(program_bytes).send().await?;
 
     if res.status() == StatusCode::UNAUTHORIZED && !anon_identity {
-        let identity = decode_identity(&config)?;
+        // If we're not in the `anon_identity` case, then we have already forced the user to log in above (using `get_auth_header`), so this should be safe to unwrap.
+        let token = config.spacetimedb_token().unwrap();
+        let identity = decode_identity(token)?;
         let err = res.text().await?;
         return unauth_error_context(
             Err(anyhow::anyhow!(err)),
@@ -198,7 +200,16 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
             ));
         }
         PublishResult::PermissionDenied { domain } => {
-            let identity = decode_identity(&config)?;
+            if anon_identity {
+                anyhow::bail!(
+                    "You need to be logged in as the owner of {} to publish to {}",
+                    domain.tld(),
+                    domain.tld()
+                );
+            }
+            // If we're not in the `anon_identity` case, then we have already forced the user to log in above (using `get_auth_header`), so this should be safe to unwrap.
+            let token = config.spacetimedb_token().unwrap();
+            let identity = decode_identity(token)?;
             //TODO(jdetter): Have a nice name generator here, instead of using some abstract characters
             // we should perhaps generate fun names like 'green-fire-dragon' instead
             let suggested_tld: String = identity.chars().take(12).collect();
diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs
index dc3d7d62ba1..aa11285f7eb 100644
--- a/crates/cli/src/subcommands/sql.rs
+++ b/crates/cli/src/subcommands/sql.rs
@@ -37,16 +37,18 @@ pub fn cli() -> clap::Command {
         )
         .arg(common_args::anonymous())
         .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
+        .arg(common_args::yes())
 }
 
-pub(crate) async fn parse_req(config: Config, args: &ArgMatches) -> Result<Connection, anyhow::Error> {
+pub(crate) async fn parse_req(mut config: Config, args: &ArgMatches) -> Result<Connection, anyhow::Error> {
     let server = args.get_one::<String>("server").map(|s| s.as_ref());
+    let force = args.get_flag("force");
     let database_name_or_identity = args.get_one::<String>("database").unwrap();
     let anon_identity = args.get_flag("anon_identity");
 
     Ok(Connection {
         host: config.get_host_url(server)?,
-        auth_header: get_auth_header(&config, anon_identity)?,
+        auth_header: get_auth_header(&mut config, anon_identity, server, !force).await?,
         database_identity: database_identity(&config, database_name_or_identity, server).await?,
         database: database_name_or_identity.to_string(),
     })
diff --git a/crates/cli/src/subcommands/subscribe.rs b/crates/cli/src/subcommands/subscribe.rs
index 46b0f4f270e..a3070e65dd5 100644
--- a/crates/cli/src/subcommands/subscribe.rs
+++ b/crates/cli/src/subcommands/subscribe.rs
@@ -63,6 +63,7 @@ pub fn cli() -> clap::Command {
                 .help("Print the initial update for the queries."),
         )
         .arg(common_args::anonymous())
+        .arg(common_args::yes())
         .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
 }
diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs
index 9d9778041dd..0c1da694904 100644
--- a/crates/cli/src/util.rs
+++ b/crates/cli/src/util.rs
@@ -3,7 +3,7 @@ use base64::{
     engine::general_purpose::STANDARD as BASE_64_STD, engine::general_purpose::STANDARD_NO_PAD as BASE_64_STD_NO_PAD,
     Engine as _,
 };
-use reqwest::RequestBuilder;
+use reqwest::{RequestBuilder, Url};
 use serde::Deserialize;
 use spacetimedb::auth::identity::{IncomingClaims, SpacetimeIdentityClaims};
 use spacetimedb_client_api_messages::name::{DnsLookupResponse, RegisterTldResult, ReverseDNSResponse};
@@ -13,6 +13,7 @@ use std::io::Write;
 use std::path::Path;
 
 use crate::config::Config;
+use crate::login::{spacetimedb_login_force, DEFAULT_AUTH_HOST};
 
 /// Determine the identity of the `database`.
 pub async fn database_identity(
@@ -46,11 +47,12 @@ pub async fn spacetime_dns(
 /// identity will be looked up in the config and it will be used instead. Returns Ok() if the
 /// domain is successfully registered, returns Err otherwise.
 pub async fn spacetime_register_tld(
-    config: &Config,
+    config: &mut Config,
     tld: &str,
     server: Option<&str>,
+    interactive: bool,
 ) -> Result<RegisterTldResult, anyhow::Error> {
-    let auth_header = get_auth_header(config, false)?;
+    let auth_header = get_auth_header(config, false, server, interactive).await?;
 
     // TODO(jdetter): Fix URL encoding on specifying this domain
     let builder = reqwest::Client::new()
@@ -128,6 +130,7 @@ pub async fn describe_reducer(
     config: &mut Config,
     server: Option<String>,
     reducer_name: String,
     anon_identity: bool,
+    interactive: bool,
 ) -> anyhow::Result {
     let builder = reqwest::Client::new().get(format!(
         "{}/database/schema/{}/{}/{}",
         config.get_host_url(server.as_deref())?,
         database,
         "reducer",
         reducer_name
     ));
-    let auth_header = get_auth_header(config, anon_identity)?;
+    let auth_header = get_auth_header(config, anon_identity, server.as_deref(), interactive).await?;
     let builder = add_auth_header_opt(builder, &auth_header);
 
     let descr = builder
@@ -170,11 +173,16 @@ pub fn add_auth_header_opt(mut builder: RequestBuilder, auth_header: &Option<String>)
 
-pub fn get_auth_header(config: &Config, anon_identity: bool) -> anyhow::Result<Option<String>> {
+pub async fn get_auth_header(
+    config: &mut Config,
+    anon_identity: bool,
+    target_server: Option<&str>,
+    interactive: bool,
+) -> anyhow::Result<Option<String>> {
     if anon_identity {
         Ok(None)
     } else {
-        let token = config.spacetimedb_token_or_error()?;
+        let token = get_login_token_or_log_in(config, target_server, interactive).await?;
         // The current form is: Authorization: Basic base64("token:")
         let mut auth_header = String::new();
         auth_header.push_str(format!("Basic {}", BASE_64_STD.encode(format!("token:{}", token))).as_str());
@@ -264,8 +272,7 @@ Please log back in with `spacetime logout` and then `spacetime login`."
     })
 }
 
-pub fn decode_identity(config: &Config) -> anyhow::Result<String> {
-    let token = config.spacetimedb_token_or_error()?;
+pub fn decode_identity(token: &String) -> anyhow::Result<String> {
     // Here, we manually extract and decode the claims from the json web token.
     // We do this without using the `jsonwebtoken` crate because it doesn't seem to have a way to skip signature verification.
     // But signature verification would require getting the public key from a server, and we don't necessarily want to do that.
@@ -281,3 +288,30 @@ pub fn decode_identity(config: &Config) -> anyhow::Result<String> {
 
     Ok(claims_data.identity.to_string())
 }
+
+pub async fn get_login_token_or_log_in(
+    config: &mut Config,
+    target_server: Option<&str>,
+    interactive: bool,
+) -> anyhow::Result<String> {
+    if let Some(token) = config.spacetimedb_token() {
+        return Ok(token.clone());
+    }
+
+    // Note: We pass `force: false` to `y_or_n` because if we're running non-interactively we want to default to "no", not yes!
+    let full_login = interactive
+        && y_or_n(
+            false,
+            // It would be "ideal" if we could print the `spacetimedb.com` by deriving it from the `default_auth_host` constant,
+            // but this will change _so_ infrequently that it's not even worth the time to write that code and test it.
+            "You are not logged in. Would you like to log in with spacetimedb.com?",
+        )?;
+
+    if full_login {
+        let host = Url::parse(DEFAULT_AUTH_HOST)?;
+        spacetimedb_login_force(config, &host, false).await
+    } else {
+        let host = Url::parse(&config.get_host_url(target_server)?)?;
+        spacetimedb_login_force(config, &host, true).await
+    }
+}
diff --git a/crates/cli/tests/snapshots/codegen__codegen_rust.snap b/crates/cli/tests/snapshots/codegen__codegen_rust.snap
index fb9309449ce..309115a88da 100644
--- a/crates/cli/tests/snapshots/codegen__codegen_rust.snap
+++ b/crates/cli/tests/snapshots/codegen__codegen_rust.snap
@@ -1544,6 +1544,7 @@ impl __sdk::EventContext for EventContext {
 
 /// A handle on a subscribed query.
 // TODO: Document this better after implementing the new subscription API.
+#[derive(Clone)]
 pub struct SubscriptionHandle {
     imp: __sdk::SubscriptionHandleImpl,
 }
 
@@ -1556,6 +1557,27 @@ impl __sdk::SubscriptionHandle for SubscriptionHandle {
     fn new(imp: __sdk::SubscriptionHandleImpl) -> Self {
         Self { imp }
     }
+
+    /// Returns true if this subscription has been terminated due to an unsubscribe call or an error.
+    fn is_ended(&self) -> bool {
+        self.imp.is_ended()
+    }
+
+    /// Returns true if this subscription has been applied and has not yet been unsubscribed.
+    fn is_active(&self) -> bool {
+        self.imp.is_active()
+    }
+
+    /// Unsubscribe from the query controlled by this `SubscriptionHandle`,
+    /// then run `on_end` when its rows are removed from the client cache.
+    fn unsubscribe_then(self, on_end: __sdk::OnEndedCallback) -> __anyhow::Result<()> {
+        self.imp.unsubscribe_then(Some(on_end))
+    }
+
+    fn unsubscribe(self) -> __anyhow::Result<()> {
+        self.imp.unsubscribe_then(None)
+    }
+
 }
 
 /// Alias trait for a [`__sdk::DbContext`] connected to this module,
diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs
index 32fe34bcd3d..be3b53c680b 100644
--- a/crates/client-api-messages/src/websocket.rs
+++ b/crates/client-api-messages/src/websocket.rs
@@ -59,6 +59,17 @@ impl<T, L: AsRef<[T]>> RowListLen for L {
     }
 }
 
+pub trait ByteListLen {
+    /// Returns the uncompressed size of the list in bytes
+    fn num_bytes(&self) -> usize;
+}
+
+impl ByteListLen for Vec<ByteString> {
+    fn num_bytes(&self) -> usize {
+        self.iter().map(|str| str.len()).sum()
+    }
+}
+
 /// A format / codec used by the websocket API.
 ///
 /// This can be e.g., BSATN, JSON.
@@ -67,7 +78,14 @@ pub trait WebsocketFormat: Sized {
     type Single: SpacetimeType + for<'de> Deserialize<'de> + Serialize + Debug + Clone;
 
     /// The type used for the encoding of a list of items.
-    type List: SpacetimeType + for<'de> Deserialize<'de> + Serialize + RowListLen + Debug + Clone + Default;
+    type List: SpacetimeType
+        + for<'de> Deserialize<'de>
+        + Serialize
+        + RowListLen
+        + ByteListLen
+        + Debug
+        + Clone
+        + Default;
 
     /// Encodes the `elems` to a list in the format and also returns the length of the list.
     fn encode_list<R: ToBsatn + Serialize>(elems: impl Iterator<Item = R>) -> (Self::List, u64);
@@ -899,6 +917,13 @@ impl<B: AsRef<[u8]>, I: AsRef<[RowOffset]>> RowListLen for BsatnRowList<B, I> {
     }
 }
 
+impl<B: AsRef<[u8]>, I> ByteListLen for BsatnRowList<B, I> {
+    /// Returns the uncompressed size of the list in bytes
+    fn num_bytes(&self) -> usize {
+        self.rows_data.as_ref().len()
+    }
+}
+
 impl BsatnRowList {
     /// Returns the element at `index` in the list.
     pub fn get(&self, index: usize) -> Option {
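A quick illustration of what the new `num_bytes` reports for the JSON list type. This assumes `Vec<ByteString>` is `JsonFormat::List` (as reconstructed above) and that `ByteString` is the `bytestring` crate's type:

```rust
use bytestring::ByteString;
use spacetimedb_client_api_messages::websocket::ByteListLen;

fn main() {
    let rows: Vec<ByteString> = vec![ByteString::from("{\"x\":1}"), ByteString::from("{\"x\":2}")];
    // Two 7-byte JSON rows: 14 bytes of uncompressed payload.
    assert_eq!(rows.num_bytes(), 14);
}
```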
diff --git a/crates/client-api/src/lib.rs b/crates/client-api/src/lib.rs
index c7e6c626d99..7ee2d86e6b6 100644
--- a/crates/client-api/src/lib.rs
+++ b/crates/client-api/src/lib.rs
@@ -6,13 +6,11 @@ use http::StatusCode;
 
 use spacetimedb::client::ClientActorIndex;
 use spacetimedb::energy::{EnergyBalance, EnergyQuanta};
-use spacetimedb::execution_context::Workload;
 use spacetimedb::host::{HostController, ModuleHost, NoSuchModule, UpdateDatabaseResult};
 use spacetimedb::identity::{AuthCtx, Identity};
 use spacetimedb::json::client_api::StmtResultJson;
 use spacetimedb::messages::control_db::{Database, HostType, Node, Replica};
 use spacetimedb::sql;
-use spacetimedb::sql::execute::translate_col;
 use spacetimedb_client_api_messages::name::{DomainName, InsertDomainResult, RegisterTldResult, Tld};
 use spacetimedb_lib::ProductTypeElement;
 use spacetimedb_paths::server::ModuleLogsDir;
@@ -82,37 +80,34 @@ impl Host {
             self.replica_id,
             move |db| -> axum::response::Result<_, (StatusCode, String)> {
                 tracing::info!(sql = body);
-                let results =
-                    sql::execute::run(db, &body, auth, Some(&module_host.info().subscriptions)).map_err(|e| {
-                        log::warn!("{}", e);
-                        if let Some(auth_err) = e.get_auth_error() {
-                            (StatusCode::UNAUTHORIZED, auth_err.to_string())
-                        } else {
-                            (StatusCode::BAD_REQUEST, e.to_string())
-                        }
-                    })?;
-
-                let json = db.with_read_only(Workload::Sql, |tx| {
-                    results
-                        .into_iter()
-                        .map(|result| {
-                            let rows = result.data;
-                            let schema = result
-                                .head
-                                .fields
-                                .iter()
-                                .map(|x| {
-                                    let ty = x.algebraic_type.clone();
-                                    let name = translate_col(tx, x.field);
-                                    ProductTypeElement::new(ty, name)
-                                })
-                                .collect();
-                            StmtResultJson { schema, rows }
-                        })
-                        .collect::<Vec<_>>()
-                });
-
-                Ok(json)
+
+                // We need a header for query results
+                let mut header = vec![];
+
+                let rows = sql::execute::run(
+                    // Returns an empty result set for mutations
+                    db,
+                    &body,
+                    auth,
+                    Some(&module_host.info().subscriptions),
+                    &mut header,
+                )
+                .map_err(|e| {
+                    log::warn!("{}", e);
+                    if let Some(auth_err) = e.get_auth_error() {
+                        (StatusCode::UNAUTHORIZED, auth_err.to_string())
+                    } else {
+                        (StatusCode::BAD_REQUEST, e.to_string())
+                    }
+                })?;
+
+                // Turn the header into a `ProductType`
+                let schema = header
+                    .into_iter()
+                    .map(|(col_name, col_type)| ProductTypeElement::new(col_type, Some(col_name)))
+                    .collect();
+
+                Ok(vec![StmtResultJson { schema, rows }])
             },
         )
         .await
diff --git a/crates/core/proptest-regressions/db/datastore/locking_tx_datastore/delete_table.txt b/crates/core/proptest-regressions/db/datastore/locking_tx_datastore/delete_table.txt
index 2813c49fa80..5f8236e3684 100644
--- a/crates/core/proptest-regressions/db/datastore/locking_tx_datastore/delete_table.txt
+++ b/crates/core/proptest-regressions/db/datastore/locking_tx_datastore/delete_table.txt
@@ -5,3 +5,4 @@
 # It is recommended to check this file in to source control so that
 # everyone who runs the test benefits from these saved cases.
 cc 9d74a36309d35ba5b09bd33cbc24f72f95a498152b469caeeeb8c2adf712974c # shrinks to (size, [ptr_a, ptr_b]) = (Size(2), [RowPointer(r: 0, pi: 0, po: 0, so: 1), RowPointer(r: 0, pi: 0, po: 0, so: 1)])
+cc daf4d4880192b4f6d95f928ed49d80465b48ff4fde4ae31aea5cfea3996cdbc6 # shrinks to minimal failing input: (size, [ptr_a, ptr_b]) = (Size(23), [RowPointer(r: 0, pi: 75, po: 828, so: 1), RowPointer(r: 0, pi: 75, po: 828, so: 1)])
diff --git a/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs b/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs
index c1913d0aa6f..dc60779c269 100644
--- a/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs
+++ b/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs
@@ -393,7 +393,8 @@ impl CommittedState {
         };
         let is_unique = unique_constraints.contains(&(table_id, (&columns).into()));
         let index = table.new_index(columns.clone(), is_unique)?;
-        table.insert_index(blob_store, index_id, index);
+        // SAFETY: `index` was derived from `table`.
+        unsafe { table.insert_index(blob_store, index_id, index) };
         self.index_id_map.insert(index_id, table_id);
     }
     Ok(())
@@ -609,7 +610,11 @@ impl CommittedState {
         for (cols, mut index) in tx_table.indexes {
             if !commit_table.indexes.contains_key(&cols) {
                 index.clear();
-                commit_table.insert_index(commit_blob_store, cols, index);
+                // SAFETY: `tx_table` is derived from `commit_table`,
+                // so they have the same row type.
+                // This entails that all indices in `tx_table`
+                // were constructed with the same row type/layout as `commit_table`.
+                unsafe { commit_table.insert_index(commit_blob_store, cols, index) };
             }
         }
 
@@ -679,6 +684,39 @@ impl CommittedState {
         let blob_store = &mut self.blob_store;
         (table, blob_store)
     }
+
+    pub(super) fn report_data_size(&self, database_identity: Identity) {
+        use crate::db::db_metrics::data_size::DATA_SIZE_METRICS;
+
+        for (table_id, table) in &self.tables {
+            let table_name = &table.schema.table_name;
+            DATA_SIZE_METRICS
+                .data_size_table_num_rows
+                .with_label_values(&database_identity, &table_id.0, table_name)
+                .set(table.num_rows() as _);
+            DATA_SIZE_METRICS
+                .data_size_table_bytes_used_by_rows
+                .with_label_values(&database_identity, &table_id.0, table_name)
+                .set(table.bytes_used_by_rows() as _);
+            DATA_SIZE_METRICS
+                .data_size_table_num_rows_in_indexes
+                .with_label_values(&database_identity, &table_id.0, table_name)
+                .set(table.num_rows_in_indexes() as _);
+            DATA_SIZE_METRICS
+                .data_size_table_bytes_used_by_index_keys
+                .with_label_values(&database_identity, &table_id.0, table_name)
+                .set(table.bytes_used_by_index_keys() as _);
+        }
+
+        DATA_SIZE_METRICS
+            .data_size_blob_store_num_blobs
+            .with_label_values(&database_identity)
+            .set(self.blob_store.num_blobs() as _);
+        DATA_SIZE_METRICS
+            .data_size_blob_store_bytes_used_by_blobs
+            .with_label_values(&database_identity)
+            .set(self.blob_store.bytes_used_by_blobs() as _);
+    }
 }
 
 pub struct CommittedIndexIterWithDeletedMutTx<'a> {
diff --git a/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs b/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs
index 5d19f5f7f07..8581288959f 100644
--- a/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs
+++ b/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs
@@ -6,11 +6,14 @@ use super::{
     tx::TxId,
     tx_state::TxState,
 };
-use crate::db::datastore::{
-    locking_tx_datastore::state_view::{IterByColRangeMutTx, IterMutTx, IterTx},
-    traits::{InsertFlags, UpdateFlags},
-};
 use crate::execution_context::Workload;
+use crate::{
+    db::datastore::{
+        locking_tx_datastore::state_view::{IterByColRangeMutTx, IterMutTx, IterTx},
+        traits::{InsertFlags, UpdateFlags},
+    },
+    subscription::record_exec_metrics,
+};
 use crate::{
     db::{
         datastore::{
@@ -34,7 +37,7 @@ use core::{cell::RefCell, ops::RangeBounds};
 use parking_lot::{Mutex, RwLock};
 use spacetimedb_commitlog::payload::{txdata, Txdata};
 use spacetimedb_durability::TxOffset;
-use spacetimedb_lib::db::auth::StAccess;
+use spacetimedb_lib::{db::auth::StAccess, metrics::ExecutionMetrics};
 use spacetimedb_lib::{Address, Identity};
 use spacetimedb_paths::server::SnapshotDirPath;
 use spacetimedb_primitives::{ColList, ConstraintId, IndexId, SequenceId, TableId};
@@ -315,11 +318,13 @@ impl Tx for Locking {
         let committed_state_shared_lock = self.committed_state.read_arc();
         let lock_wait_time = timer.elapsed();
         let ctx = ExecutionContext::with_workload(self.database_identity, workload);
+        let metrics = ExecutionMetrics::default();
         Self::Tx {
             committed_state_shared_lock,
             lock_wait_time,
             timer,
             ctx,
+            metrics,
         }
     }
 
@@ -629,13 +634,14 @@ impl MutTxDatastore for Locking {
 }
 
 /// This utility is responsible for recording all transaction metrics.
-pub(super) fn record_metrics(
+pub(super) fn record_tx_metrics(
     ctx: &ExecutionContext,
     tx_timer: Instant,
     lock_wait_time: Duration,
     committed: bool,
     tx_data: Option<&TxData>,
     committed_state: Option<&CommittedState>,
+    metrics: ExecutionMetrics,
 ) {
     let workload = &ctx.workload();
     let db = &ctx.database_identity();
@@ -662,6 +668,8 @@ pub(super) fn record_metrics(
         .with_label_values(workload, db, reducer)
         .observe(elapsed_time);
 
+    record_exec_metrics(workload, db, metrics);
+
     /// Update table rows and table size gauges,
     /// and sets them to zero if no table is present.
     fn update_table_gauges(db: &Identity, table_id: &TableId, table_name: &str, table: Option<&Table>) {
@@ -698,6 +706,12 @@ pub(super) fn record_metrics(
                 .inc_by(deletes.len() as u64);
         }
     }
+
+    if let Some(committed_state) = committed_state {
+        // TODO(cleanliness,bikeshedding): Consider inlining `report_data_size` here,
+        // or moving the above metric writes into it, for consistency of organization.
+        committed_state.report_data_size(*db);
+    }
 }
 
 impl MutTx for Locking {
@@ -719,6 +733,7 @@ impl MutTx for Locking {
             lock_wait_time,
             timer,
             ctx,
+            metrics: ExecutionMetrics::default(),
         }
     }
 
@@ -995,10 +1010,10 @@ mod tests {
     use crate::db::datastore::system_tables::{
         system_tables, StColumnRow, StConstraintData, StConstraintFields, StConstraintRow, StIndexAlgorithm,
         StIndexFields, StIndexRow, StRowLevelSecurityFields, StScheduledFields, StSequenceFields, StSequenceRow,
-        StTableRow, StVarFields, StVarValue, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, ST_CONSTRAINT_ID,
-        ST_CONSTRAINT_NAME, ST_INDEX_ID, ST_INDEX_NAME, ST_MODULE_NAME, ST_RESERVED_SEQUENCE_RANGE,
-        ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_NAME, ST_SCHEDULED_ID, ST_SCHEDULED_NAME, ST_SEQUENCE_ID,
-        ST_SEQUENCE_NAME, ST_TABLE_NAME, ST_VAR_ID, ST_VAR_NAME,
+        StTableRow, StVarFields, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_NAME,
+        ST_INDEX_ID, ST_INDEX_NAME, ST_MODULE_NAME, ST_RESERVED_SEQUENCE_RANGE, ST_ROW_LEVEL_SECURITY_ID,
+        ST_ROW_LEVEL_SECURITY_NAME, ST_SCHEDULED_ID, ST_SCHEDULED_NAME, ST_SEQUENCE_ID, ST_SEQUENCE_NAME,
+        ST_TABLE_NAME, ST_VAR_ID, ST_VAR_NAME,
     };
     use crate::db::datastore::traits::{IsolationLevel, MutTx};
     use crate::db::datastore::Result;
@@ -1009,6 +1024,7 @@ mod tests {
     use pretty_assertions::{assert_eq, assert_matches};
     use spacetimedb_lib::db::auth::{StAccess, StTableType};
     use spacetimedb_lib::error::ResultTest;
+    use spacetimedb_lib::st_var::StVarValue;
     use spacetimedb_lib::{resolved_type_via_v9, ScheduleAt};
     use spacetimedb_primitives::{col_list, ColId, ScheduleId};
     use spacetimedb_sats::algebraic_value::ser::value_serialize;
@@ -266,7 +268,10 @@ mod test { #[test] fn deleting_non_existent_does_nothing((size, [ptr_a, ptr_b]) in gen_size_and_ptrs()) { + prop_assume!(ptr_a != ptr_b); + let mut dt = TestDT::new(size); + prop_assert!(!dt.remove(ptr_b)); prop_assert!(dt.insert(ptr_a)); prop_assert!(!dt.remove(ptr_b)); diff --git a/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs b/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs index 4b39fe43cc0..70373cddb25 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs @@ -1,6 +1,6 @@ use super::{ committed_state::CommittedState, - datastore::{record_metrics, Result}, + datastore::{record_tx_metrics, Result}, delete_table::DeleteTable, sequence::{Sequence, SequencesState}, state_view::{IndexSeekIterIdMutTx, ScanIterByColRangeMutTx, StateView}, @@ -33,8 +33,12 @@ use core::cell::RefCell; use core::ops::RangeBounds; use core::{iter, ops::Bound}; use smallvec::SmallVec; -use spacetimedb_lib::db::raw_def::v9::RawSql; -use spacetimedb_lib::db::{auth::StAccess, raw_def::SEQUENCE_ALLOCATION_STEP}; +use spacetimedb_execution::{dml::MutDatastore, Datastore, DeltaStore}; +use spacetimedb_lib::{db::raw_def::v9::RawSql, metrics::ExecutionMetrics}; +use spacetimedb_lib::{ + db::{auth::StAccess, raw_def::SEQUENCE_ALLOCATION_STEP}, + query::Delta, +}; use spacetimedb_primitives::{ColId, ColList, ColSet, ConstraintId, IndexId, ScheduleId, SequenceId, TableId}; use spacetimedb_sats::{ bsatn::{self, to_writer, DecodeError, Deserializer}, @@ -70,6 +74,49 @@ pub struct MutTxId { pub(super) lock_wait_time: Duration, pub(crate) timer: Instant, pub(crate) ctx: ExecutionContext, + pub(crate) metrics: ExecutionMetrics, +} + +impl Datastore for MutTxId { + fn blob_store(&self) -> &dyn BlobStore { + &self.committed_state_write_lock.blob_store + } + + fn table(&self, table_id: TableId) -> Option<&Table> { + self.committed_state_write_lock.get_table(table_id) + } +} + +/// Note, deltas are evaluated using read-only transactions, not mutable ones. +/// Nevertheless this contract is still required for query evaluation. +impl DeltaStore for MutTxId { + fn has_inserts(&self, _: TableId) -> Option { + None + } + + fn has_deletes(&self, _: TableId) -> Option { + None + } + + fn inserts_for_table(&self, _: TableId) -> Option> { + None + } + + fn deletes_for_table(&self, _: TableId) -> Option> { + None + } +} + +impl MutDatastore for MutTxId { + fn insert_product_value(&mut self, table_id: TableId, row: &ProductValue) -> anyhow::Result<()> { + self.insert_via_serialize_bsatn(table_id, row)?; + Ok(()) + } + + fn delete_product_value(&mut self, table_id: TableId, row: &ProductValue) -> anyhow::Result<()> { + self.delete_by_row_value(table_id, row)?; + Ok(()) + } } impl MutTxId { @@ -403,7 +450,11 @@ impl MutTxId { // the existing rows having the same value for some column(s). let mut insert_index = table.new_index(columns.clone(), is_unique)?; let mut build_from_rows = |table: &Table, bs: &dyn BlobStore| -> Result<()> { - if let Some(violation) = insert_index.build_from_rows(table.scan_rows(bs))? { + let rows = table.scan_rows(bs); + // SAFETY: (1) `insert_index` was derived from `table` + // which in turn was derived from `commit_table`. 
+ let violation = unsafe { insert_index.build_from_rows(rows) }; + if let Err(violation) = violation { let violation = table .get_row_ref(bs, violation) .expect("row came from scanning the table") @@ -434,7 +485,8 @@ impl MutTxId { columns ); - table.add_index(index_id, insert_index); + // SAFETY: same as (1). + unsafe { table.add_index(index_id, insert_index) }; // Associate `index_id -> table_id` for fast lookup. idx_map.insert(index_id, table_id); @@ -1052,13 +1104,14 @@ impl MutTxId { let tx_data = committed_state_write_lock.merge(tx_state, &self.ctx); // Record metrics for the transaction at the very end, // right before we drop and release the lock. - record_metrics( + record_tx_metrics( &self.ctx, self.timer, self.lock_wait_time, true, Some(&tx_data), Some(&committed_state_write_lock), + self.metrics, ); tx_data } @@ -1072,13 +1125,14 @@ impl MutTxId { let tx_data = committed_state_write_lock.merge(tx_state, &self.ctx); // Record metrics for the transaction at the very end, // right before we drop and release the lock. - record_metrics( + record_tx_metrics( &self.ctx, self.timer, self.lock_wait_time, true, Some(&tx_data), Some(&committed_state_write_lock), + self.metrics, ); // Update the workload type of the execution context self.ctx.workload = workload.into(); @@ -1087,6 +1141,7 @@ impl MutTxId { lock_wait_time: Duration::ZERO, timer: Instant::now(), ctx: self.ctx, + metrics: ExecutionMetrics::default(), }; (tx_data, tx) } @@ -1094,13 +1149,29 @@ impl MutTxId { pub fn rollback(self) { // Record metrics for the transaction at the very end, // right before we drop and release the lock. - record_metrics(&self.ctx, self.timer, self.lock_wait_time, false, None, None); + record_tx_metrics( + &self.ctx, + self.timer, + self.lock_wait_time, + false, + None, + None, + self.metrics, + ); } pub fn rollback_downgrade(mut self, workload: Workload) -> TxId { // Record metrics for the transaction at the very end, // right before we drop and release the lock. - record_metrics(&self.ctx, self.timer, self.lock_wait_time, false, None, None); + record_tx_metrics( + &self.ctx, + self.timer, + self.lock_wait_time, + false, + None, + None, + self.metrics, + ); // Update the workload type of the execution context self.ctx.workload = workload.into(); TxId { @@ -1108,6 +1179,7 @@ impl MutTxId { lock_wait_time: Duration::ZERO, timer: Instant::now(), ctx: self.ctx, + metrics: ExecutionMetrics::default(), } } } @@ -1341,11 +1413,11 @@ impl MutTxId { // but it could do so wrt., `commit_table`, // assuming the conflicting row hasn't been deleted since. // Ensure that it doesn't, or roll back the insertion. - if let Err(e) = commit_table.check_unique_constraints( - tx_row_ref, - |ixs| ixs, - |commit_ptr| delete_table.contains(commit_ptr), - ) { + let is_deleted = |commit_ptr| delete_table.contains(commit_ptr); + // SAFETY: `commit_table.row_layout() == tx_row_ref.row_layout()` holds + // as the `tx_table` is derived from `commit_table`. + let res = unsafe { commit_table.check_unique_constraints(tx_row_ref, |ixs| ixs, is_deleted) }; + if let Err(e) = res { // There was a constraint violation, so undo the insertion. 
tx_table.delete(tx_blob_store, tx_row_ptr, |_| {}); return Err(IndexError::from(e).into()); @@ -1460,24 +1532,27 @@ impl MutTxId { new_row: RowRef<'_>, old_ptr: RowPointer, ) -> Result<()> { - commit_table - .check_unique_constraints( + let is_deleted = |commit_ptr| commit_ptr == old_ptr || del_table.contains(commit_ptr); + // SAFETY: `commit_table.row_layout() == new_row.row_layout()` holds + // as the `tx_table` is derived from `commit_table`. + let res = unsafe { + commit_table.check_unique_constraints( new_row, // Don't check this index since we'll do a 1-1 old/new replacement. |ixs| ixs.filter(|(&id, _)| id != ignore_index_id), - |commit_ptr| commit_ptr == old_ptr || del_table.contains(commit_ptr), + is_deleted, ) - .map_err(IndexError::from) - .map_err(Into::into) + }; + res.map_err(IndexError::from).map_err(Into::into) } /// Projects the new row to the index to find the old row. #[inline] fn find_old_row(new_row: RowRef<'_>, index: TableAndIndex<'_>) -> (Option, AlgebraicValue) { let index = index.index(); // Project the row to the index's columns/type. - let needle = new_row - .project(&index.indexed_columns) - .expect("projecting a table row to one of the table's indices should never fail"); + // SAFETY: `new_row` belongs to the same table as `index`, + // so all `index.indexed_columns` will be in-bounds of the row layout. + let needle = unsafe { new_row.project_unchecked(&index.indexed_columns) }; // Find the old row. (index.seek(&needle).next(), needle) } diff --git a/crates/core/src/db/datastore/locking_tx_datastore/tx.rs b/crates/core/src/db/datastore/locking_tx_datastore/tx.rs index ae155b0047f..d1c0c7636ad 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/tx.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/tx.rs @@ -1,4 +1,4 @@ -use super::datastore::record_metrics; +use super::datastore::record_tx_metrics; use super::{ committed_state::CommittedState, datastore::Result, @@ -8,6 +8,7 @@ use super::{ use crate::db::datastore::locking_tx_datastore::state_view::IterTx; use crate::execution_context::ExecutionContext; use spacetimedb_execution::Datastore; +use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_primitives::{ColList, TableId}; use spacetimedb_sats::AlgebraicValue; use spacetimedb_schema::schema::TableSchema; @@ -25,6 +26,7 @@ pub struct TxId { pub(super) lock_wait_time: Duration, pub(super) timer: Instant, pub(crate) ctx: ExecutionContext, + pub(crate) metrics: ExecutionMetrics, } impl Datastore for TxId { @@ -86,7 +88,15 @@ impl StateView for TxId { impl TxId { pub(super) fn release(self) { - record_metrics(&self.ctx, self.timer, self.lock_wait_time, true, None, None); + record_tx_metrics( + &self.ctx, + self.timer, + self.lock_wait_time, + true, + None, + None, + self.metrics, + ); } /// The Number of Distinct Values (NDV) for a column or list of columns, diff --git a/crates/core/src/db/datastore/system_tables.rs b/crates/core/src/db/datastore/system_tables.rs index d5e29fbc41f..705a6e1e490 100644 --- a/crates/core/src/db/datastore/system_tables.rs +++ b/crates/core/src/db/datastore/system_tables.rs @@ -13,21 +13,19 @@ use crate::db::relational_db::RelationalDB; use crate::error::DBError; -use derive_more::From; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::db::raw_def::v9::{RawIndexAlgorithm, RawSql}; use spacetimedb_lib::db::raw_def::*; use spacetimedb_lib::de::{Deserialize, DeserializeOwned, Error}; use spacetimedb_lib::ser::Serialize; +use spacetimedb_lib::st_var::StVarValue; use 
spacetimedb_lib::{Address, Identity, ProductValue, SpacetimeType}; use spacetimedb_primitives::*; use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; use spacetimedb_sats::algebraic_value::ser::value_serialize; use spacetimedb_sats::hash::Hash; use spacetimedb_sats::product_value::InvalidFieldError; -use spacetimedb_sats::{ - impl_deserialize, impl_serialize, impl_st, u256, AlgebraicType, AlgebraicValue, ArrayValue, SumValue, -}; +use spacetimedb_sats::{impl_deserialize, impl_serialize, impl_st, u256, AlgebraicType, AlgebraicValue, ArrayValue}; use spacetimedb_schema::def::{BTreeAlgorithm, ConstraintData, IndexAlgorithm, ModuleDef, UniqueConstraintData}; use spacetimedb_schema::schema::{ ColumnSchema, ConstraintSchema, IndexSchema, RowLevelSecuritySchema, ScheduleSchema, Schema, SequenceSchema, @@ -1098,146 +1096,6 @@ impl StVarName { } } -/// The value of a system variable in `st_var` -#[derive(Debug, Clone, From, SpacetimeType)] -#[sats(crate = spacetimedb_lib)] -pub enum StVarValue { - Bool(bool), - I8(i8), - U8(u8), - I16(i16), - U16(u16), - I32(i32), - U32(u32), - I64(i64), - U64(u64), - I128(i128), - U128(u128), - // No support for u/i256 added here as it seems unlikely to be useful. - F32(f32), - F64(f64), - String(Box), -} - -impl StVarValue { - pub fn try_from_primitive(value: AlgebraicValue) -> Result { - match value { - AlgebraicValue::Bool(v) => Ok(StVarValue::Bool(v)), - AlgebraicValue::I8(v) => Ok(StVarValue::I8(v)), - AlgebraicValue::U8(v) => Ok(StVarValue::U8(v)), - AlgebraicValue::I16(v) => Ok(StVarValue::I16(v)), - AlgebraicValue::U16(v) => Ok(StVarValue::U16(v)), - AlgebraicValue::I32(v) => Ok(StVarValue::I32(v)), - AlgebraicValue::U32(v) => Ok(StVarValue::U32(v)), - AlgebraicValue::I64(v) => Ok(StVarValue::I64(v)), - AlgebraicValue::U64(v) => Ok(StVarValue::U64(v)), - AlgebraicValue::I128(v) => Ok(StVarValue::I128(v.0)), - AlgebraicValue::U128(v) => Ok(StVarValue::U128(v.0)), - AlgebraicValue::F32(v) => Ok(StVarValue::F32(v.into_inner())), - AlgebraicValue::F64(v) => Ok(StVarValue::F64(v.into_inner())), - AlgebraicValue::String(v) => Ok(StVarValue::String(v)), - _ => Err(value), - } - } - - pub fn try_from_sum(value: AlgebraicValue) -> Result { - value.into_sum()?.try_into() - } -} - -impl TryFrom for StVarValue { - type Error = AlgebraicValue; - - fn try_from(sum: SumValue) -> Result { - match sum.tag { - 0 => Ok(StVarValue::Bool(sum.value.into_bool()?)), - 1 => Ok(StVarValue::I8(sum.value.into_i8()?)), - 2 => Ok(StVarValue::U8(sum.value.into_u8()?)), - 3 => Ok(StVarValue::I16(sum.value.into_i16()?)), - 4 => Ok(StVarValue::U16(sum.value.into_u16()?)), - 5 => Ok(StVarValue::I32(sum.value.into_i32()?)), - 6 => Ok(StVarValue::U32(sum.value.into_u32()?)), - 7 => Ok(StVarValue::I64(sum.value.into_i64()?)), - 8 => Ok(StVarValue::U64(sum.value.into_u64()?)), - 9 => Ok(StVarValue::I128(sum.value.into_i128()?.0)), - 10 => Ok(StVarValue::U128(sum.value.into_u128()?.0)), - 11 => Ok(StVarValue::F32(sum.value.into_f32()?.into_inner())), - 12 => Ok(StVarValue::F64(sum.value.into_f64()?.into_inner())), - 13 => Ok(StVarValue::String(sum.value.into_string()?)), - _ => Err(*sum.value), - } - } -} - -impl From for AlgebraicValue { - fn from(value: StVarValue) -> Self { - AlgebraicValue::Sum(value.into()) - } -} - -impl From for SumValue { - fn from(value: StVarValue) -> Self { - match value { - StVarValue::Bool(v) => SumValue { - tag: 0, - value: Box::new(AlgebraicValue::Bool(v)), - }, - StVarValue::I8(v) => SumValue { - tag: 1, - value: Box::new(AlgebraicValue::I8(v)), - }, - 
StVarValue::U8(v) => SumValue { - tag: 2, - value: Box::new(AlgebraicValue::U8(v)), - }, - StVarValue::I16(v) => SumValue { - tag: 3, - value: Box::new(AlgebraicValue::I16(v)), - }, - StVarValue::U16(v) => SumValue { - tag: 4, - value: Box::new(AlgebraicValue::U16(v)), - }, - StVarValue::I32(v) => SumValue { - tag: 5, - value: Box::new(AlgebraicValue::I32(v)), - }, - StVarValue::U32(v) => SumValue { - tag: 6, - value: Box::new(AlgebraicValue::U32(v)), - }, - StVarValue::I64(v) => SumValue { - tag: 7, - value: Box::new(AlgebraicValue::I64(v)), - }, - StVarValue::U64(v) => SumValue { - tag: 8, - value: Box::new(AlgebraicValue::U64(v)), - }, - StVarValue::I128(v) => SumValue { - tag: 9, - value: Box::new(AlgebraicValue::I128(v.into())), - }, - StVarValue::U128(v) => SumValue { - tag: 10, - value: Box::new(AlgebraicValue::U128(v.into())), - }, - StVarValue::F32(v) => SumValue { - tag: 11, - value: Box::new(AlgebraicValue::F32(v.into())), - }, - StVarValue::F64(v) => SumValue { - tag: 12, - value: Box::new(AlgebraicValue::F64(v.into())), - }, - StVarValue::String(v) => SumValue { - tag: 13, - value: Box::new(AlgebraicValue::String(v)), - }, - } - } -} - impl TryFrom> for StVarRow { type Error = DBError; diff --git a/crates/core/src/db/db_metrics/data_size.rs b/crates/core/src/db/db_metrics/data_size.rs new file mode 100644 index 00000000000..09430010a32 --- /dev/null +++ b/crates/core/src/db/db_metrics/data_size.rs @@ -0,0 +1,42 @@ +use once_cell::sync::Lazy; +use prometheus::IntGaugeVec; +use spacetimedb_lib::Identity; +use spacetimedb_metrics::metrics_group; + +metrics_group!( + #[non_exhaustive] + pub struct DbDataSize { + #[name = spacetime_data_size_table_num_rows] + #[help = "The number of rows in a table"] + #[labels(db: Identity, table_id: u32, table_name: str)] + pub data_size_table_num_rows: IntGaugeVec, + + #[name = spacetime_data_size_bytes_used_by_rows] + #[help = "The number of bytes used by rows in pages in a table"] + #[labels(db: Identity, table_id: u32, table_name: str)] + pub data_size_table_bytes_used_by_rows: IntGaugeVec, + + #[name = spacetime_data_size_table_num_rows_in_indexes] + #[help = "The number of rows stored in indexes in a table"] + // TODO: Consider partitioning by index ID or index name. 
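+    // (For now this gauge is labeled per table, like the row and byte gauges above.)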
+ #[labels(db: Identity, table_id: u32, table_name: str)] + pub data_size_table_num_rows_in_indexes: IntGaugeVec, + + #[name = spacetime_data_size_table_bytes_used_by_index_keys] + #[help = "The number of bytes used by keys stored in indexes in a table"] + #[labels(db: Identity, table_id: u32, table_name: str)] + pub data_size_table_bytes_used_by_index_keys: IntGaugeVec, + + #[name = spacetime_data_size_blob_store_num_blobs] + #[help = "The number of large blobs stored in a database's blob store"] + #[labels(db: Identity)] + pub data_size_blob_store_num_blobs: IntGaugeVec, + + #[name = spacetime_data_size_blob_store_bytes_used_by_blobs] + #[help = "The number of bytes used by large blobs stored in a database's blob store"] + #[labels(db: Identity)] + pub data_size_blob_store_bytes_used_by_blobs: IntGaugeVec, + } +); + +pub static DATA_SIZE_METRICS: Lazy = Lazy::new(DbDataSize::new); diff --git a/crates/core/src/db/db_metrics/mod.rs b/crates/core/src/db/db_metrics/mod.rs index 8dd965fa200..9c538cbb7f6 100644 --- a/crates/core/src/db/db_metrics/mod.rs +++ b/crates/core/src/db/db_metrics/mod.rs @@ -5,6 +5,8 @@ use spacetimedb_lib::Identity; use spacetimedb_metrics::metrics_group; use spacetimedb_primitives::TableId; +pub mod data_size; + metrics_group!( #[non_exhaustive] pub struct DbMetrics { @@ -23,19 +25,24 @@ metrics_group!( #[labels(txn_type: WorkloadType, db: Identity, reducer_or_query: str, table_id: u32, table_name: str)] pub rdb_num_rows_deleted: IntCounterVec, - #[name = spacetime_num_rows_fetched_total] - #[help = "The cumulative number of rows fetched from a table"] - #[labels(txn_type: WorkloadType, db: Identity, reducer_or_query: str, table_id: u32, table_name: str)] - pub rdb_num_rows_fetched: IntCounterVec, + #[name = spacetime_num_rows_scanned_total] + #[help = "The cumulative number of rows scanned from the database"] + #[labels(txn_type: WorkloadType, db: Identity)] + pub rdb_num_rows_scanned: IntCounterVec, - #[name = spacetime_num_index_keys_scanned_total] - #[help = "The cumulative number of keys scanned from an index"] - #[labels(txn_type: WorkloadType, db: Identity, reducer_or_query: str, table_id: u32, table_name: str)] - pub rdb_num_keys_scanned: IntCounterVec, + #[name = spacetime_num_bytes_scanned_total] + #[help = "The cumulative number of bytes scanned from the database"] + #[labels(txn_type: WorkloadType, db: Identity)] + pub rdb_num_bytes_scanned: IntCounterVec, + + #[name = spacetime_num_bytes_written_total] + #[help = "The cumulative number of bytes written to the database"] + #[labels(txn_type: WorkloadType, db: Identity)] + pub rdb_num_bytes_written: IntCounterVec, #[name = spacetime_num_index_seeks_total] #[help = "The cumulative number of index seeks"] - #[labels(txn_type: WorkloadType, db: Identity, reducer_or_query: str, table_id: u32, table_name: str)] + #[labels(txn_type: WorkloadType, db: Identity)] pub rdb_num_index_seeks: IntCounterVec, #[name = spacetime_num_txns_total] diff --git a/crates/core/src/host/instance_env.rs b/crates/core/src/host/instance_env.rs index 4078d00a26b..ab05ba9d71a 100644 --- a/crates/core/src/host/instance_env.rs +++ b/crates/core/src/host/instance_env.rs @@ -14,7 +14,7 @@ use spacetimedb_sats::{ AlgebraicValue, ProductValue, }; use spacetimedb_table::indexes::RowPointer; -use spacetimedb_table::table::{RowRef, UniqueConstraintViolation}; +use spacetimedb_table::table::RowRef; use std::ops::DerefMut; use std::sync::Arc; @@ -80,7 +80,12 @@ impl ChunkedWriter { self.chunks } - pub fn collect_iter(pool: &mut ChunkPool, iter: 
impl Iterator) -> Vec> { + pub fn collect_iter( + pool: &mut ChunkPool, + iter: impl Iterator, + rows_scanned: &mut usize, + bytes_scanned: &mut usize, + ) -> Vec> { let mut chunked_writer = Self::default(); // Consume the iterator, serializing each `item`, // while allowing a chunk to be created at boundaries. @@ -89,9 +94,16 @@ impl ChunkedWriter { item.to_bsatn_extend(&mut chunked_writer.curr).unwrap(); // Flush at item boundaries. chunked_writer.flush(pool); + // Update rows scanned + *rows_scanned += 1; } - chunked_writer.into_chunks() + let chunks = chunked_writer.into_chunks(); + + // Update (BSATN) bytes scanned + *bytes_scanned += chunks.iter().map(|chunk| chunk.len()).sum::(); + + chunks } } @@ -148,22 +160,30 @@ impl InstanceEnv { let row_len = Self::project_cols_bsatn(buffer, gen_cols, row_ref); (row_len, row_ref.pointer(), insert_flags) }) - .inspect_err(|e| match e { - DBError::Index(IndexError::UniqueConstraintViolation(UniqueConstraintViolation { .. })) => {} - _ => { - let res = stdb.table_name_from_id_mut(tx, table_id); - if let Ok(Some(table_name)) = res { - log::debug!("insert(table: {table_name}, table_id: {table_id}): {e}") - } else { - log::debug!("insert(table_id: {table_id}): {e}") + .inspect_err( + #[cold] + #[inline(never)] + |e| match e { + DBError::Index(IndexError::UniqueConstraintViolation(_)) => {} + _ => { + let res = stdb.table_name_from_id_mut(tx, table_id); + if let Ok(Some(table_name)) = res { + log::debug!("insert(table: {table_name}, table_id: {table_id}): {e}") + } else { + log::debug!("insert(table_id: {table_id}): {e}") + } } - } - })?; + }, + )?; if insert_flags.is_scheduler_table { self.schedule_row(stdb, tx, table_id, row_ptr)?; } + // Note, we update the metric for bytes written after the insert. + // This is to capture auto-inc columns. + tx.metrics.bytes_written += buffer.len(); + Ok(row_len) } @@ -202,17 +222,21 @@ impl InstanceEnv { let row_len = Self::project_cols_bsatn(buffer, gen_cols, row_ref); (row_len, row_ref.pointer(), update_flags) }) - .inspect_err(|e| match e { - DBError::Index(IndexError::UniqueConstraintViolation(UniqueConstraintViolation { .. })) => {} - _ => { - let res = stdb.table_name_from_id_mut(tx, table_id); - if let Ok(Some(table_name)) = res { - log::debug!("update(table: {table_name}, table_id: {table_id}, index_id: {index_id}): {e}") - } else { - log::debug!("update(table_id: {table_id}, index_id: {index_id}): {e}") + .inspect_err( + #[cold] + #[inline(never)] + |e| match e { + DBError::Index(IndexError::UniqueConstraintViolation(_)) => {} + _ => { + let res = stdb.table_name_from_id_mut(tx, table_id); + if let Ok(Some(table_name)) = res { + log::debug!("update(table: {table_name}, table_id: {table_id}, index_id: {index_id}): {e}") + } else { + log::debug!("update(table_id: {table_id}, index_id: {index_id}): {e}") + } } - } - })?; + }, + )?; if update_flags.is_scheduler_table { self.schedule_row(stdb, tx, table_id, row_ptr)?; @@ -238,6 +262,14 @@ impl InstanceEnv { // Re. `SmallVec`, `delete_by_field` only cares about 1 element, so optimize for that. let rows_to_delete = iter.map(|row_ref| row_ref.pointer()).collect::>(); + // Note, we're deleting rows based on the result of a btree scan. + // Hence we must update our `index_seeks` and `rows_scanned` metrics. + // + // Note that we're not updating `bytes_scanned` at all, + // because we never dereference any of the returned `RowPointer`s. + tx.metrics.index_seeks += 1; + tx.metrics.rows_scanned += rows_to_delete.len(); + // Delete them and count how many we deleted. 
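+        // (Note the single seek above covers the whole range scan, however many pointers it returned.)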
Ok(stdb.delete(tx, table_id, rows_to_delete)) } @@ -255,12 +287,20 @@ impl InstanceEnv { let stdb = &*self.replica_ctx.relational_db; let tx = &mut *self.get_tx()?; + // Track the number of bytes coming from the caller + tx.metrics.bytes_scanned += relation.len(); + // Find the row schema using it to decode a vector of product values. let row_ty = stdb.row_schema_for_table(tx, table_id)?; // `TableType::delete` cares about a single element // so in that case we can avoid the allocation by using `smallvec`. let relation = ProductValue::decode_smallvec(&row_ty, &mut &*relation).map_err(NodesError::DecodeRow)?; + // Note, we track the number of rows coming from the caller, + // regardless of whether or not we actually delete them, + // since we have to derive row ids for each one of them. + tx.metrics.rows_scanned += relation.len(); + // Delete them and return how many we deleted. Ok(stdb.delete_by_rel(tx, table_id, relation)) } @@ -315,7 +355,21 @@ impl InstanceEnv { let stdb = &*self.replica_ctx.relational_db; let tx = &mut *self.tx.get()?; - let chunks = ChunkedWriter::collect_iter(pool, stdb.iter_mut(tx, table_id)?); + // Track the number of rows and the number of bytes scanned by the iterator + let mut rows_scanned = 0; + let mut bytes_scanned = 0; + + // Scan table and serialize rows to bsatn + let chunks = ChunkedWriter::collect_iter( + pool, + stdb.iter_mut(tx, table_id)?, + &mut rows_scanned, + &mut bytes_scanned, + ); + + tx.metrics.rows_scanned += rows_scanned; + tx.metrics.bytes_scanned += bytes_scanned; + Ok(chunks) } @@ -332,8 +386,20 @@ impl InstanceEnv { let stdb = &*self.replica_ctx.relational_db; let tx = &mut *self.tx.get()?; + // Track rows and bytes scanned by the iterator + let mut rows_scanned = 0; + let mut bytes_scanned = 0; + + // Open index iterator let (_, iter) = stdb.btree_scan(tx, index_id, prefix, prefix_elems, rstart, rend)?; - let chunks = ChunkedWriter::collect_iter(pool, iter); + + // Scan the index and serialize rows to bsatn + let chunks = ChunkedWriter::collect_iter(pool, iter, &mut rows_scanned, &mut bytes_scanned); + + tx.metrics.index_seeks += 1; + tx.metrics.rows_scanned += rows_scanned; + tx.metrics.bytes_scanned += bytes_scanned; + Ok(chunks) } } @@ -365,3 +431,313 @@ impl From for NodesError { NodesError::NotInTransaction } } + +#[cfg(test)] +mod test { + use std::{ops::Bound, sync::Arc}; + + use anyhow::{anyhow, Result}; + use spacetimedb_lib::{bsatn::to_vec, AlgebraicType, AlgebraicValue, Hash, Identity, ProductValue}; + use spacetimedb_paths::{server::ModuleLogsDir, FromPathUnchecked}; + use spacetimedb_primitives::{IndexId, TableId}; + use spacetimedb_sats::product; + use tempfile::TempDir; + + use crate::{ + database_logger::DatabaseLogger, + db::{ + datastore::traits::IsolationLevel, + relational_db::{tests_utils::TestDB, RelationalDB}, + }, + execution_context::Workload, + host::Scheduler, + messages::control_db::{Database, HostType}, + replica_context::ReplicaContext, + subscription::module_subscription_actor::ModuleSubscriptions, + }; + + use super::{ChunkPool, InstanceEnv, TxSlot}; + + /// An `InstanceEnv` requires a `DatabaseLogger` + fn temp_logger() -> Result { + let temp = TempDir::new()?; + let path = ModuleLogsDir::from_path_unchecked(temp.into_path()); + let path = path.today(); + Ok(DatabaseLogger::open(path)) + } + + /// An `InstanceEnv` requires `ModuleSubscriptions` + fn subscription_actor(relational_db: Arc) -> ModuleSubscriptions { + ModuleSubscriptions::new(relational_db, Identity::ZERO) + } + + /// An `InstanceEnv` 
requires a `ReplicaContext`. + /// For our purposes this is just a wrapper for `RelationalDB`. + fn replica_ctx(relational_db: Arc) -> Result { + Ok(ReplicaContext { + database: Database { + id: 0, + database_identity: Identity::ZERO, + owner_identity: Identity::ZERO, + host_type: HostType::Wasm, + initial_program: Hash::ZERO, + }, + replica_id: 0, + logger: Arc::new(temp_logger()?), + subscriptions: subscription_actor(relational_db.clone()), + relational_db, + }) + } + + /// An `InstanceEnv` used for testing the database syscalls. + fn instance_env(db: Arc) -> Result { + let (scheduler, _) = Scheduler::open(db.clone()); + Ok(InstanceEnv { + replica_ctx: Arc::new(replica_ctx(db)?), + scheduler, + tx: TxSlot::default(), + }) + } + + /// An in-memory `RelationalDB` for testing. + /// It does not persist data to disk. + fn relational_db() -> Result> { + let TestDB { db, .. } = TestDB::in_memory()?; + Ok(Arc::new(db)) + } + + /// Generate a `ProductValue` for use in [create_table_with_index] + fn product_row(i: usize) -> ProductValue { + let str = i.to_string(); + let str = str.repeat(i); + let id = i as u64; + product!(id, str) + } + + /// Generate a BSATN encoded row for use in [create_table_with_index] + fn bsatn_row(i: usize) -> Result> { + Ok(to_vec(&product_row(i))?) + } + + /// Instantiate the following table: + /// + /// ```text + /// id | str + /// -- | --- + /// 1 | "1" + /// 2 | "22" + /// 3 | "333" + /// 4 | "4444" + /// 5 | "55555" + /// ``` + /// + /// with an index on `id`. + fn create_table_with_index(db: &RelationalDB) -> Result<(TableId, IndexId)> { + let table_id = db.create_table_for_test( + "t", + &[("id", AlgebraicType::U64), ("str", AlgebraicType::String)], + &[0.into()], + )?; + let index_id = db.with_read_only(Workload::ForTests, |tx| { + db.schema_for_table(tx, table_id)? + .indexes + .iter() + .find(|schema| { + schema + .index_algorithm + .columns() + .as_singleton() + .is_some_and(|col_id| col_id.idx() == 0) + }) + .map(|schema| schema.index_id) + .ok_or_else(|| anyhow!("Index not found for ColId `{}`", 0)) + })?; + db.with_auto_commit(Workload::ForTests, |tx| -> Result<_> { + for i in 1..=5 { + db.insert(tx, table_id, &bsatn_row(i)?)?; + } + Ok(()) + })?; + Ok((table_id, index_id)) + } + + #[test] + fn table_scan_metrics() -> Result<()> { + let db = relational_db()?; + let env = instance_env(db.clone())?; + + let (table_id, _) = create_table_with_index(&db)?; + + let mut tx_slot = env.tx.clone(); + + let f = || env.datastore_table_scan_bsatn_chunks(&mut ChunkPool::default(), table_id); + let tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::ForTests); + let (tx, scan_result) = tx_slot.set(tx, f); + + scan_result?; + + let bytes_scanned = (1..=5) + .map(bsatn_row) + .filter_map(|bsatn_result| bsatn_result.ok()) + .map(|bsatn| bsatn.len()) + .sum::(); + + // The only non-zero metrics should be rows and bytes scanned. + // The table has 5 rows, so we should have 5 rows scanned. + // We should also have scanned the same number of bytes that we inserted. 
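+        // (A sequential scan never touches an index and writes nothing, so the remaining metrics stay zero.)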
+ assert_eq!(0, tx.metrics.index_seeks); + assert_eq!(5, tx.metrics.rows_scanned); + assert_eq!(bytes_scanned, tx.metrics.bytes_scanned); + assert_eq!(0, tx.metrics.bytes_written); + assert_eq!(0, tx.metrics.bytes_sent_to_clients); + Ok(()) + } + + #[test] + fn index_scan_metrics() -> Result<()> { + let db = relational_db()?; + let env = instance_env(db.clone())?; + + let (_, index_id) = create_table_with_index(&db)?; + + let mut tx_slot = env.tx.clone(); + + // Perform two index scans + let f = || -> Result<_> { + let index_key_3 = to_vec(&Bound::Included(AlgebraicValue::U64(3)))?; + let index_key_5 = to_vec(&Bound::Included(AlgebraicValue::U64(5)))?; + env.datastore_btree_scan_bsatn_chunks( + &mut ChunkPool::default(), + index_id, + &[], + 0.into(), + &index_key_3, + &index_key_3, + )?; + env.datastore_btree_scan_bsatn_chunks( + &mut ChunkPool::default(), + index_id, + &[], + 0.into(), + &index_key_5, + &index_key_5, + )?; + Ok(()) + }; + let tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::ForTests); + let (tx, scan_result) = tx_slot.set(tx, f); + + scan_result?; + + let bytes_scanned = [3, 5] + .into_iter() + .map(bsatn_row) + .filter_map(|bsatn_result| bsatn_result.ok()) + .map(|bsatn| bsatn.len()) + .sum::(); + + // We performed two index scans to fetch rows 3 and 5 + assert_eq!(2, tx.metrics.index_seeks); + assert_eq!(2, tx.metrics.rows_scanned); + assert_eq!(bytes_scanned, tx.metrics.bytes_scanned); + assert_eq!(0, tx.metrics.bytes_written); + assert_eq!(0, tx.metrics.bytes_sent_to_clients); + Ok(()) + } + + #[test] + fn insert_metrics() -> Result<()> { + let db = relational_db()?; + let env = instance_env(db.clone())?; + + let (table_id, _) = create_table_with_index(&db)?; + + let mut tx_slot = env.tx.clone(); + + // Insert 4 new rows into `t` + let f = || -> Result<_> { + for i in 6..=9 { + let mut buffer = bsatn_row(i)?; + env.insert(table_id, &mut buffer)?; + } + Ok(()) + }; + let tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::ForTests); + let (tx, insert_result) = tx_slot.set(tx, f); + + insert_result?; + + let bytes_written = (6..=9) + .map(bsatn_row) + .filter_map(|bsatn_result| bsatn_result.ok()) + .map(|bsatn| bsatn.len()) + .sum::(); + + // The only metric affected by inserts is bytes written + assert_eq!(0, tx.metrics.index_seeks); + assert_eq!(0, tx.metrics.rows_scanned); + assert_eq!(0, tx.metrics.bytes_scanned); + assert_eq!(bytes_written, tx.metrics.bytes_written); + assert_eq!(0, tx.metrics.bytes_sent_to_clients); + Ok(()) + } + + #[test] + fn delete_by_index_metrics() -> Result<()> { + let db = relational_db()?; + let env = instance_env(db.clone())?; + + let (_, index_id) = create_table_with_index(&db)?; + + let mut tx_slot = env.tx.clone(); + + // Delete a single row via the index + let f = || -> Result<_> { + let index_key = to_vec(&Bound::Included(AlgebraicValue::U64(3)))?; + env.datastore_delete_by_btree_scan_bsatn(index_id, &[], 0.into(), &index_key, &index_key)?; + Ok(()) + }; + let tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::ForTests); + let (tx, delete_result) = tx_slot.set(tx, f); + + delete_result?; + + assert_eq!(1, tx.metrics.index_seeks); + assert_eq!(1, tx.metrics.rows_scanned); + assert_eq!(0, tx.metrics.bytes_scanned); + assert_eq!(0, tx.metrics.bytes_written); + assert_eq!(0, tx.metrics.bytes_sent_to_clients); + Ok(()) + } + + #[test] + fn delete_by_value_metrics() -> Result<()> { + let db = relational_db()?; + let env = instance_env(db.clone())?; + + let (table_id, _) = create_table_with_index(&db)?; + + 
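        // The cloned slot lets us install our transaction for the inserts below.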
let mut tx_slot = env.tx.clone(); + + let bsatn_rows = to_vec(&(3..=5).map(product_row).collect::>())?; + + // Delete 3 rows by value + let f = || -> Result<_> { + env.datastore_delete_all_by_eq_bsatn(table_id, &bsatn_rows)?; + Ok(()) + }; + let tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::ForTests); + let (tx, delete_result) = tx_slot.set(tx, f); + + delete_result?; + + let bytes_scanned = bsatn_rows.len(); + + assert_eq!(0, tx.metrics.index_seeks); + assert_eq!(3, tx.metrics.rows_scanned); + assert_eq!(bytes_scanned, tx.metrics.bytes_scanned); + assert_eq!(0, tx.metrics.bytes_written); + assert_eq!(0, tx.metrics.bytes_sent_to_clients); + Ok(()) + } +} diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 0bc5682a9b1..273898ea71d 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -7,13 +7,14 @@ use crate::db::datastore::traits::{IsolationLevel, Program, TxData}; use crate::energy::EnergyQuanta; use crate::error::DBError; use crate::estimation::estimate_rows_scanned; -use crate::execution_context::{ExecutionContext, ReducerContext, Workload}; +use crate::execution_context::{ExecutionContext, ReducerContext, Workload, WorkloadType}; use crate::hash::Hash; use crate::identity::Identity; use crate::messages::control_db::Database; use crate::replica_context::ReplicaContext; use crate::sql::ast::SchemaViewer; use crate::subscription::module_subscription_actor::ModuleSubscriptions; +use crate::subscription::record_exec_metrics; use crate::subscription::tx::DeltaTx; use crate::util::lending_pool::{Closed, LendingPool, LentResource, PoolClosed}; use crate::vm::check_row_limit; @@ -26,7 +27,7 @@ use indexmap::IndexSet; use itertools::Itertools; use smallvec::SmallVec; use spacetimedb_client_api_messages::timestamp::Timestamp; -use spacetimedb_client_api_messages::websocket::{Compression, OneOffTable, QueryUpdate, WebsocketFormat}; +use spacetimedb_client_api_messages::websocket::{ByteListLen, Compression, OneOffTable, QueryUpdate, WebsocketFormat}; use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_lib::db::raw_def::v9::Lifecycle; @@ -128,13 +129,14 @@ impl UpdatesRelValue<'_> { !(self.deletes.is_empty() && self.inserts.is_empty()) } - pub fn encode(&self, compression: Compression) -> (F::QueryUpdate, u64) { + pub fn encode(&self, compression: Compression) -> (F::QueryUpdate, u64, usize) { let (deletes, nr_del) = F::encode_list(self.deletes.iter()); let (inserts, nr_ins) = F::encode_list(self.inserts.iter()); let num_rows = nr_del + nr_ins; + let num_bytes = deletes.num_bytes() + inserts.num_bytes(); let qu = QueryUpdate { deletes, inserts }; let cqu = F::into_query_update(qu, compression); - (cqu, num_rows) + (cqu, num_rows, num_bytes) } } @@ -824,17 +826,26 @@ impl ModuleHost { let auth = AuthCtx::new(replica_ctx.owner_identity, caller_identity); log::debug!("One-off query: {query}"); - db.with_read_only(Workload::Sql, |tx| { + let (rows, metrics) = db.with_read_only(Workload::Sql, |tx| { let tx = SchemaViewer::new(tx, &auth); let plan = SubscribePlan::compile(&query, &tx)?; check_row_limit(&plan, db, &tx, |plan, tx| estimate_rows_scanned(tx, plan), &auth)?; plan.execute::<_, F>(&DeltaTx::from(&*tx)) - .map(|(rows, _)| OneOffTable { - table_name: plan.table_name().to_owned().into_boxed_str(), - rows, + .map(|(rows, _, metrics)| { + ( + OneOffTable { + table_name: 
plan.table_name().to_owned().into_boxed_str(), + rows, + }, + metrics, + ) }) .context("One-off queries are not allowed to modify the database") - }) + })?; + + record_exec_metrics(&WorkloadType::Sql, &db.database_identity(), metrics); + + Ok(rows) } /// FIXME(jgilles): this is a temporary workaround for deleting not currently being supported diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 62f0f2a84a5..061f66537a0 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use std::time::Duration; use super::instrumentation::CallTimes; -use crate::database_logger::SystemLogger; +use crate::database_logger::{self, SystemLogger}; use crate::db::datastore::locking_tx_datastore::MutTxId; use crate::db::datastore::system_tables::{StClientRow, ST_CLIENT_ID}; use crate::db::datastore::traits::{IsolationLevel, Program}; @@ -77,6 +77,7 @@ pub struct ExecutionTimings { pub struct ExecuteResult { pub energy: EnergyStats, pub timings: ExecutionTimings, + pub memory_allocation: usize, pub call_result: Result>, E>, } @@ -193,6 +194,8 @@ impl WasmModuleHostActor { instance, info: self.info.clone(), energy_monitor: self.energy_monitor.clone(), + // will be updated on the first reducer call + allocated_memory: 0, trapped: false, } } @@ -236,6 +239,7 @@ pub struct WasmModuleInstance { instance: T, info: Arc, energy_monitor: Arc, + allocated_memory: usize, trapped: bool, } @@ -410,7 +414,7 @@ impl WasmModuleInstance { let replica_ctx = self.replica_context(); let stdb = &*replica_ctx.relational_db.clone(); - let address = replica_ctx.database_identity; + let database_identity = replica_ctx.database_identity; let reducer_def = self.info.module_def.reducer_by_id(reducer_id); let reducer_name = &*reducer_def.name; @@ -446,7 +450,7 @@ impl WasmModuleInstance { }); let _guard = WORKER_METRICS .reducer_plus_query_duration - .with_label_values(&address, op.name) + .with_label_values(&database_identity, op.name) .with_timer(tx.timer); let mut tx_slot = self.instance.instance_env().tx.clone(); @@ -466,11 +470,19 @@ impl WasmModuleInstance { let ExecuteResult { energy, timings, + memory_allocation, call_result, } = result; self.energy_monitor .record_reducer(&energy_fingerprint, energy.used, timings.total_duration); + if self.allocated_memory != memory_allocation { + WORKER_METRICS + .wasm_memory_bytes + .with_label_values(&database_identity) + .set(memory_allocation as i64); + self.allocated_memory = memory_allocation; + } reducer_span .record("timings.total_duration", tracing::field::debug(timings.total_duration)) @@ -508,6 +520,17 @@ impl WasmModuleInstance { Ok(Err(errmsg)) => { log::info!("reducer returned error: {errmsg}"); + self.replica_context().logger.write( + database_logger::LogLevel::Error, + &database_logger::Record { + ts: chrono::DateTime::from_timestamp_micros(timestamp.microseconds as i64).unwrap(), + target: Some(reducer_name), + filename: None, + line_number: None, + message: &errmsg, + }, + &(), + ); EventStatus::Failed(errmsg.into()) } // we haven't actually comitted yet - `commit_and_broadcast_event` will commit diff --git a/crates/core/src/host/wasmtime/wasmtime_module.rs b/crates/core/src/host/wasmtime/wasmtime_module.rs index 7c0d9261878..c3671b2bfb8 100644 --- a/crates/core/src/host/wasmtime/wasmtime_module.rs +++ b/crates/core/src/host/wasmtime/wasmtime_module.rs @@ -227,10 +227,12 @@ impl 
module_host_actor::WasmInstance for WasmtimeInstance { used: (budget - remaining).into(), remaining, }; + let memory_allocation = store.data().get_mem().memory.data_size(&store); module_host_actor::ExecuteResult { energy, timings, + memory_allocation, call_result, } } diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index 90441eed9a6..2404114f708 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -372,7 +372,6 @@ mod tests { ); run_for_testing(&db, sql)?; - let tx = db.begin_tx(Workload::ForTests); // Compile query, check for both hex formats and it to be case-insensitive... let sql = &format!( "select * from test where identity = {} AND identity_mix = x'93dda09db9a56d8fa6c024d843e805D8262191db3b4bA84c5efcd1ad451fed4e' AND address = x'{}' AND address = {}", @@ -383,6 +382,7 @@ mod tests { let rows = run_for_testing(&db, sql)?; + let tx = db.begin_tx(Workload::ForTests); let CrudExpr::Query(QueryExpr { source: _, query: mut ops, @@ -398,7 +398,7 @@ mod tests { panic!("Expected Select"); }; - assert_eq!(rows[0].data, vec![row]); + assert_eq!(rows, vec![row]); Ok(()) } diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 41962fffa1f..d8d4d670426 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -1,23 +1,28 @@ use std::time::Duration; -use super::compiler::compile_sql; +use super::ast::SchemaViewer; use crate::db::datastore::locking_tx_datastore::state_view::StateView; use crate::db::datastore::system_tables::StVarTable; use crate::db::datastore::traits::IsolationLevel; use crate::db::relational_db::{RelationalDB, Tx}; use crate::energy::EnergyQuanta; use crate::error::DBError; +use crate::estimation::estimate_rows_scanned; use crate::execution_context::Workload; use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall}; use crate::host::ArgsTuple; use crate::subscription::module_subscription_actor::{ModuleSubscriptions, WriteConflict}; +use crate::subscription::tx::DeltaTx; use crate::util::slow::SlowQueryLogger; -use crate::vm::{DbProgram, TxMode}; -use itertools::Either; +use crate::vm::{check_row_limit, DbProgram, TxMode}; +use anyhow::anyhow; use spacetimedb_client_api_messages::timestamp::Timestamp; +use spacetimedb_expr::statement::Statement; use spacetimedb_lib::identity::AuthCtx; +use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::relation::FieldName; -use spacetimedb_lib::{ProductType, ProductValue}; +use spacetimedb_lib::{AlgebraicType, ProductType, ProductValue}; +use spacetimedb_query::{compile_sql_stmt, execute_dml_stmt, execute_select_stmt}; use spacetimedb_vm::eval::run_ast; use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr}; use spacetimedb_vm::relation::MemTable; @@ -172,30 +177,93 @@ pub fn run( sql_text: &str, auth: AuthCtx, subs: Option<&ModuleSubscriptions>, -) -> Result, DBError> { - let result = db.with_read_only(Workload::Sql, |tx| { - let ast = compile_sql(db, &AuthCtx::for_testing(), tx, sql_text)?; - if CrudExpr::is_reads(&ast) { - let mut updates = Vec::new(); - let result = execute( - &mut DbProgram::new(db, &mut TxMode::Tx(tx), auth), - ast, - sql_text, - &mut updates, - )?; - Ok::<_, DBError>(Either::Left(result)) - } else { - // hehe. right. write. - Ok(Either::Right(ast)) - } + head: &mut Vec<(Box, AlgebraicType)>, +) -> Result, DBError> { + // We parse the sql statement in a mutable transation. + // If it turns out to be a query, we downgrade the tx. 
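+    // We can't tell whether the statement writes until it is compiled,
+    // so we start mutable and downgrade once we know it's a select.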
+ let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| { + compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth)) })?; - match result { - Either::Left(result) => Ok(result), - // TODO: this should perhaps be an upgradable_read upgrade? or we should try - // and figure out if we can detect the mutablility of the query before we take - // the tx? once we have migrations we probably don't want to have stale - // sql queries after a database schema have been updated. - Either::Right(ast) => execute_sql(db, sql_text, ast, auth, subs), + + let mut metrics = ExecutionMetrics::default(); + + match stmt { + Statement::Select(stmt) => { + // Up to this point, the tx has been read-only, + // and hence there are no deltas to process. + let (_, tx) = tx.commit_downgrade(Workload::Sql); + + // Release the tx on drop, so that we record metrics + let mut tx = scopeguard::guard(tx, |tx| { + db.release_tx(tx); + }); + + // Compute the header for the result set + stmt.for_each_return_field(|col_name, col_type| { + head.push((col_name.into(), col_type.clone())); + }); + + // Evaluate the query + let rows = execute_select_stmt(stmt, &DeltaTx::from(&*tx), &mut metrics, |plan| { + check_row_limit(&plan, db, &tx, |plan, tx| estimate_rows_scanned(tx, plan), &auth)?; + Ok(plan) + })?; + + // Update transaction metrics + tx.metrics.merge(metrics); + + Ok(rows) + } + Statement::DML(stmt) => { + // An extra layer of auth is required for DML + if auth.caller != auth.owner { + return Err(anyhow!("Only owners are authorized to run SQL DML statements").into()); + } + + // Evaluate the mutation + let (mut tx, _) = db.with_auto_rollback(tx, |tx| execute_dml_stmt(stmt, tx, &mut metrics))?; + + // Update transaction metrics + tx.metrics.merge(metrics); + + // Commit the tx if there are no deltas to process + if subs.is_none() { + return db.commit_tx(tx).map(|_| vec![]); + } + + // Otherwise downgrade the tx and process the deltas. + // Note, we get the delta by downgrading the tx. + // Hence we just pass a default `DatabaseUpdate` here. + // It will ultimately be replaced with the correct one. 
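+            // Committing through the subscription layer ensures subscribers
+            // see the DML's changes as a transaction update.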
+ match subs + .unwrap() + .commit_and_broadcast_event( + None, + ModuleEvent { + timestamp: Timestamp::now(), + caller_identity: auth.caller, + caller_address: None, + function_call: ModuleFunctionCall { + reducer: String::new(), + reducer_id: u32::MAX.into(), + args: ArgsTuple::default(), + }, + status: EventStatus::Committed(DatabaseUpdate::default()), + energy_quanta_used: EnergyQuanta::ZERO, + host_execution_duration: Duration::ZERO, + request_id: None, + timer: None, + }, + tx, + ) + .unwrap() + { + Err(WriteConflict) => { + todo!("See module_host_actor::call_reducer_with_tx") + } + Ok(_) => Ok(vec![]), + } + } } } @@ -218,12 +286,11 @@ pub(crate) mod tests { use pretty_assertions::assert_eq; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::error::{ResultTest, TestError}; - use spacetimedb_lib::relation::ColExpr; use spacetimedb_lib::relation::Header; use spacetimedb_lib::{AlgebraicValue, Identity}; use spacetimedb_primitives::{col_list, ColId}; use spacetimedb_sats::{product, AlgebraicType, ArrayValue, ProductType}; - use spacetimedb_vm::eval::test_helpers::{create_game_data, mem_table, mem_table_without_table_name}; + use spacetimedb_vm::eval::test_helpers::create_game_data; use std::sync::Arc; pub(crate) fn execute_for_testing( @@ -236,9 +303,9 @@ pub(crate) mod tests { } /// Short-cut for simplify test execution - pub(crate) fn run_for_testing(db: &RelationalDB, sql_text: &str) -> Result, DBError> { + pub(crate) fn run_for_testing(db: &RelationalDB, sql_text: &str) -> Result, DBError> { let subs = ModuleSubscriptions::new(Arc::new(db.clone()), Identity::ZERO); - run(db, sql_text, AuthCtx::for_testing(), Some(&subs)) + run(db, sql_text, AuthCtx::for_testing(), Some(&subs), &mut vec![]) } fn create_data(total_rows: u64) -> ResultTest<(TestDB, MemTable)> { @@ -263,14 +330,7 @@ pub(crate) mod tests { let result = run_for_testing(&db, "SELECT * FROM inventory")?; - assert_eq!(result.len(), 1, "Not return results"); - let result = result.first().unwrap().clone(); - - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, input.data, "Inventory"); Ok(()) } @@ -279,31 +339,15 @@ pub(crate) mod tests { let (db, input) = create_data(1)?; let result = run_for_testing(&db, "SELECT inventory.* FROM inventory")?; - assert_eq!(result.len(), 1, "Not return results"); - let result = result.first().unwrap().clone(); - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, input.data, "Inventory"); let result = run_for_testing( &db, "SELECT inventory.inventory_id FROM inventory WHERE inventory.inventory_id = 1", )?; - assert_eq!(result.len(), 1, "Not return results"); - let result = result.first().unwrap().clone(); - let head = ProductType::from([("inventory_id", AlgebraicType::U64)]); - let row = product!(1u64); - let input = mem_table(input.head.table_id, head, vec![row]); - - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, vec![product!(1u64)], "Inventory"); Ok(()) } @@ -313,7 +357,6 @@ pub(crate) mod tests { let (db, _) = create_data(1)?; let tx = db.begin_tx(Workload::ForTests); - let schema = db.schema_for_table(&tx, ST_TABLE_ID).unwrap(); db.release_tx(tx); let result = run_for_testing( @@ -321,8 +364,6 @@ pub(crate) mod tests { &format!("SELECT * FROM {} WHERE table_id = {}", ST_TABLE_NAME, ST_TABLE_ID), )?; - 
assert_eq!(result.len(), 1, "Not return results"); - let result = result.first().unwrap().clone(); let pk_col_id: ColId = StTableFields::TableId.into(); let row = product![ ST_TABLE_ID, @@ -331,108 +372,62 @@ pub(crate) mod tests { StAccess::Public.as_str(), Some(AlgebraicValue::Array(ArrayValue::U16(vec![pk_col_id.0].into()))), ]; - let input = MemTable::new(Header::from(&*schema).into(), schema.table_access, vec![row]); - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "st_table" - ); + assert_eq!(result, vec![row], "st_table"); Ok(()) } #[test] fn test_select_column() -> ResultTest<()> { - let (db, table) = create_data(1)?; + let (db, _) = create_data(1)?; let result = run_for_testing(&db, "SELECT inventory_id FROM inventory")?; - assert_eq!(result.len(), 1, "Not return results"); - let result = result.first().unwrap().clone(); - // The expected result. - let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); - let row = product![1u64]; - let input = MemTable::new(inv.into(), table.table_access, vec![row]); - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, vec![row], "Inventory"); Ok(()) } #[test] fn test_where() -> ResultTest<()> { - let (db, table) = create_data(1)?; + let (db, _) = create_data(1)?; let result = run_for_testing(&db, "SELECT inventory_id FROM inventory WHERE inventory_id = 1")?; - assert_eq!(result.len(), 1, "Not return results"); - let result = result.first().unwrap().clone(); - - // The expected result. - let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); - let row = product![1u64]; - let input = MemTable::new(inv.into(), table.table_access, vec![row]); - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, vec![row], "Inventory"); Ok(()) } #[test] fn test_or() -> ResultTest<()> { - let (db, table) = create_data(2)?; + let (db, _) = create_data(2)?; - let result = run_for_testing( + let mut result = run_for_testing( &db, "SELECT inventory_id FROM inventory WHERE inventory_id = 1 OR inventory_id = 2", )?; - assert_eq!(result.len(), 1, "Not return results"); - let mut result = result.first().unwrap().clone(); - result.data.sort(); - //The expected result - let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); - - let input = MemTable::new(inv.into(), table.table_access, vec![product![1u64], product![2u64]]); + result.sort(); - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, vec![product![1u64], product![2u64]], "Inventory"); Ok(()) } #[test] fn test_nested() -> ResultTest<()> { - let (db, table) = create_data(2)?; + let (db, _) = create_data(2)?; - let result = run_for_testing( + let mut result = run_for_testing( &db, "SELECT inventory_id FROM inventory WHERE (inventory_id = 1 OR inventory_id = 2 AND (true))", )?; - assert_eq!(result.len(), 1, "Not return results"); - let mut result = result.first().unwrap().clone(); - result.data.sort(); - // The expected result. 
- let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); + result.sort(); - let input = MemTable::new(inv.into(), table.table_access, vec![product![1u64], product![2u64]]); - - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, vec![product![1u64], product![2u64]], "Inventory"); Ok(()) } @@ -442,7 +437,7 @@ pub(crate) mod tests { let db = TestDB::durable()?; - let (p_schema, inv_schema) = db.with_auto_commit::<_, _, TestError>(Workload::ForTests, |tx| { + db.with_auto_commit::<_, _, TestError>(Workload::ForTests, |tx| { let i = create_table_with_rows(&db, tx, "Inventory", data.inv_ty, &data.inv.data, StAccess::Public)?; let p = create_table_with_rows(&db, tx, "Player", data.player_ty, &data.player.data, StAccess::Public)?; create_table_with_rows( @@ -456,7 +451,7 @@ pub(crate) mod tests { Ok((p, i)) })?; - let result = &run_for_testing( + let result = run_for_testing( &db, "SELECT Player.* @@ -465,18 +460,13 @@ pub(crate) mod tests { JOIN Location ON Location.entity_id = Player.entity_id WHERE Location.x > 0 AND Location.x <= 32 AND Location.z > 0 AND Location.z <= 32", - )?[0]; + )?; let row1 = product!(100u64, 1u64); - let input = MemTable::new(Header::from(&*p_schema).into(), p_schema.table_access, [row1].into()); - assert_eq!( - mem_table_without_table_name(result), - mem_table_without_table_name(&input), - "Player JOIN Location" - ); + assert_eq!(result, vec![row1], "Player JOIN Location"); - let result = &run_for_testing( + let result = run_for_testing( &db, "SELECT Inventory.* @@ -487,20 +477,11 @@ pub(crate) mod tests { JOIN Location ON Player.entity_id = Location.entity_id WHERE Location.x > 0 AND Location.x <= 32 AND Location.z > 0 AND Location.z <= 32", - )?[0]; + )?; let row1 = product!(1u64, "health"); - let input = MemTable::new( - Header::from(&*inv_schema).into(), - inv_schema.table_access, - [row1].into(), - ); - assert_eq!( - mem_table_without_table_name(result), - mem_table_without_table_name(&input), - "Inventory JOIN Player JOIN Location" - ); + assert_eq!(result, vec![row1], "Inventory JOIN Player JOIN Location"); Ok(()) } @@ -512,20 +493,13 @@ pub(crate) mod tests { assert_eq!(result.len(), 0, "Return results"); - let result = run_for_testing(&db, "SELECT * FROM inventory")?; - - assert_eq!(result.len(), 1, "Not return results"); - let mut result = result.first().unwrap().clone(); + let mut result = run_for_testing(&db, "SELECT * FROM inventory")?; input.data.push(product![2u64, "test"]); input.data.sort(); - result.data.sort(); + result.sort(); - assert_eq!( - mem_table_without_table_name(&result), - mem_table_without_table_name(&input), - "Inventory" - ); + assert_eq!(result, input.data, "Inventory"); Ok(()) } @@ -538,29 +512,17 @@ pub(crate) mod tests { run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (3, 't3')")?; let result = run_for_testing(&db, "SELECT * FROM inventory")?; - assert_eq!( - result.iter().map(|x| x.data.len()).sum::(), - 3, - "Not return results" - ); + assert_eq!(result.len(), 3, "Not return results"); run_for_testing(&db, "DELETE FROM inventory WHERE inventory.inventory_id = 3")?; let result = run_for_testing(&db, "SELECT * FROM inventory")?; - assert_eq!( - result.iter().map(|x| x.data.len()).sum::(), - 2, - "Not delete correct row?" 
- ); + assert_eq!(result.len(), 2, "Not delete correct row?"); run_for_testing(&db, "DELETE FROM inventory")?; let result = run_for_testing(&db, "SELECT * FROM inventory")?; - assert_eq!( - result.iter().map(|x| x.data.len()).sum::(), - 0, - "Not delete all rows" - ); + assert_eq!(result.len(), 0, "Not delete all rows"); Ok(()) } @@ -576,17 +538,11 @@ pub(crate) mod tests { let result = run_for_testing(&db, "SELECT * FROM inventory WHERE inventory_id = 2")?; - let result = result.first().unwrap().clone(); - let mut change = input; change.data.clear(); change.data.push(product![2u64, "c2"]); - assert_eq!( - mem_table_without_table_name(&change), - mem_table_without_table_name(&result), - "Update Inventory 2" - ); + assert_eq!(result, change.data, "Update Inventory 2"); run_for_testing(&db, "UPDATE inventory SET name = 'c3'")?; @@ -594,14 +550,9 @@ pub(crate) mod tests { let updated: Vec<_> = result .into_iter() - .map(|x| { - x.data - .into_iter() - .map(|x| x.field_as_str(1, None).unwrap().to_string()) - .collect::>() - }) + .map(|x| x.field_as_str(1, None).unwrap().to_string()) .collect(); - assert_eq!(vec![vec!["c3"; 3]], updated); + assert_eq!(vec!["c3"; 3], updated); Ok(()) } @@ -624,8 +575,7 @@ pub(crate) mod tests { let result = run_for_testing(&db, "select * from test where b = 1 and a = 1")?; - let result = result.first().unwrap().clone(); - assert_eq!(result.data, vec![product![1, 1, 1, 1]]); + assert_eq!(result, vec![product![1, 1, 1, 1]]); Ok(()) } @@ -672,20 +622,13 @@ pub(crate) mod tests { .unwrap(); let result = run_for_testing(&db, "select * from test where x > 5 and x < 5").unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].data.is_empty()); + assert!(result.is_empty()); let result = run_for_testing(&db, "select * from test where x >= 5 and x < 4").unwrap(); - assert_eq!(result.len(), 1); - assert!( - result[0].data.is_empty(), - "Expected no rows but found {:#?}", - result[0].data - ); + assert!(result.is_empty(), "Expected no rows but found {:#?}", result); let result = run_for_testing(&db, "select * from test where x > 5 and x <= 4").unwrap(); - assert_eq!(result.len(), 1); - assert!(result[0].data.is_empty()); + assert!(result.is_empty()); Ok(()) } @@ -703,8 +646,7 @@ pub(crate) mod tests { let result = run_for_testing(&db, "select * from test where a >= 3 and a <= 5 and b >= 3 and b <= 5")?; - let result = result.first().unwrap().clone(); - assert_eq!(result.data, []); + assert!(result.is_empty()); Ok(()) } @@ -727,6 +669,8 @@ pub(crate) mod tests { let internal_auth = AuthCtx::new(server, server); let external_auth = AuthCtx::new(server, client); + let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]); + // No row limit, both queries pass. 
assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok()); assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok()); diff --git a/crates/core/src/sql/parser.rs b/crates/core/src/sql/parser.rs index 8d5c096a603..3b987f5cfed 100644 --- a/crates/core/src/sql/parser.rs +++ b/crates/core/src/sql/parser.rs @@ -22,7 +22,7 @@ impl RowLevelExpr { Ok(Self { def: RowLevelSecuritySchema { - table_id: sql.table_id().unwrap(), + table_id: sql.return_table_id().unwrap(), sql: rls.sql.clone(), }, sql, diff --git a/crates/core/src/subscription/delta.rs b/crates/core/src/subscription/delta.rs index df160de544c..a30e4dd506a 100644 --- a/crates/core/src/subscription/delta.rs +++ b/crates/core/src/subscription/delta.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::Result; use spacetimedb_execution::{Datastore, DeltaStore}; +use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_query::delta::DeltaPlanEvaluator; use spacetimedb_vm::relation::RelValue; @@ -15,34 +16,35 @@ use crate::host::module_host::UpdatesRelValue; /// Hence this may be removed at any time after 1.0. pub fn eval_delta<'a, Tx: Datastore + DeltaStore>( tx: &'a Tx, + metrics: &mut ExecutionMetrics, delta: &'a DeltaPlanEvaluator, ) -> Result> { if !delta.is_join() { return Ok(UpdatesRelValue { - inserts: delta.eval_inserts(tx)?.map(RelValue::from).collect(), - deletes: delta.eval_deletes(tx)?.map(RelValue::from).collect(), + inserts: delta.eval_inserts(tx, metrics)?.map(RelValue::from).collect(), + deletes: delta.eval_deletes(tx, metrics)?.map(RelValue::from).collect(), }); } if delta.has_inserts() && !delta.has_deletes() { return Ok(UpdatesRelValue { - inserts: delta.eval_inserts(tx)?.map(RelValue::from).collect(), + inserts: delta.eval_inserts(tx, metrics)?.map(RelValue::from).collect(), deletes: vec![], }); } if delta.has_deletes() && !delta.has_inserts() { return Ok(UpdatesRelValue { - deletes: delta.eval_deletes(tx)?.map(RelValue::from).collect(), + deletes: delta.eval_deletes(tx, metrics)?.map(RelValue::from).collect(), inserts: vec![], }); } let mut inserts = HashMap::new(); - for row in delta.eval_inserts(tx)?.map(RelValue::from) { + for row in delta.eval_inserts(tx, metrics)?.map(RelValue::from) { inserts.entry(row).and_modify(|n| *n += 1).or_insert(1); } let deletes = delta - .eval_deletes(tx)? + .eval_deletes(tx, metrics)? 
.map(RelValue::from) .filter(|row| match inserts.get_mut(row) { None => true, diff --git a/crates/core/src/subscription/mod.rs b/crates/core/src/subscription/mod.rs index e9bb8e519e4..6c7a167aa79 100644 --- a/crates/core/src/subscription/mod.rs +++ b/crates/core/src/subscription/mod.rs @@ -1,3 +1,7 @@ +use spacetimedb_lib::{metrics::ExecutionMetrics, Identity}; + +use crate::{db::db_metrics::DB_METRICS, execution_context::WorkloadType, worker_metrics::WORKER_METRICS}; + pub mod delta; pub mod execution_unit; pub mod module_subscription_actor; @@ -6,3 +10,27 @@ pub mod query; #[allow(clippy::module_inception)] // it's right this isn't ideal :/ pub mod subscription; pub mod tx; + +/// Update the global system metrics with transaction-level execution metrics +pub(crate) fn record_exec_metrics(workload: &WorkloadType, db: &Identity, metrics: ExecutionMetrics) { + DB_METRICS + .rdb_num_index_seeks + .with_label_values(workload, db) + .inc_by(metrics.index_seeks as u64); + DB_METRICS + .rdb_num_rows_scanned + .with_label_values(workload, db) + .inc_by(metrics.rows_scanned as u64); + DB_METRICS + .rdb_num_bytes_scanned + .with_label_values(workload, db) + .inc_by(metrics.bytes_scanned as u64); + DB_METRICS + .rdb_num_bytes_written + .with_label_values(workload, db) + .inc_by(metrics.bytes_written as u64); + WORKER_METRICS + .bytes_sent_to_clients + .with_label_values(workload, db) + .inc_by(metrics.bytes_sent_to_clients as u64); +} diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index a9976dc0965..26eb26e9908 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,6 +1,7 @@ use super::execution_unit::QueryHash; use super::module_subscription_manager::{Plan, SubscriptionManager}; use super::query::compile_read_only_query; +use super::record_exec_metrics; use super::tx::DeltaTx; use crate::client::messages::{ SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionRows, SubscriptionUpdateMessage, @@ -11,7 +12,7 @@ use crate::db::datastore::locking_tx_datastore::tx::TxId; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::DBError; use crate::estimation::estimate_rows_scanned; -use crate::execution_context::Workload; +use crate::execution_context::{Workload, WorkloadType}; use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent}; use crate::messages::websocket::Subscribe; use crate::vm::check_row_limit; @@ -21,6 +22,7 @@ use spacetimedb_client_api_messages::websocket::{ BsatnFormat, FormatSwitch, JsonFormat, SubscribeSingle, TableUpdate, Unsubscribe, }; use spacetimedb_lib::identity::AuthCtx; +use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Identity; use spacetimedb_query::{execute_plans, SubscribePlan}; use std::{sync::Arc, time::Instant}; @@ -37,6 +39,7 @@ pub struct ModuleSubscriptions { } type AssertTxFn = Arc; +type SubscriptionUpdate = FormatSwitch, TableUpdate>; impl ModuleSubscriptions { pub fn new(relational_db: Arc, owner_identity: Identity) -> Self { @@ -54,7 +57,7 @@ impl ModuleSubscriptions { query: Arc, tx: &TxId, auth: &AuthCtx, - ) -> Result, TableUpdate>, DBError> { + ) -> Result<(SubscriptionUpdate, ExecutionMetrics), DBError> { let comp = sender.config.compression; let plan = SubscribePlan::from_delta_plan(&query); @@ -68,8 +71,12 @@ impl ModuleSubscriptions { let tx = DeltaTx::from(tx); Ok(match sender.config.protocol { - Protocol::Binary 
=> FormatSwitch::Bsatn(plan.collect_table_update(comp, &tx)?), - Protocol::Text => FormatSwitch::Json(plan.collect_table_update(comp, &tx)?), + Protocol::Binary => plan + .collect_table_update(comp, &tx) + .map(|(table_update, metrics)| (FormatSwitch::Bsatn(table_update), metrics))?, + Protocol::Text => plan + .collect_table_update(comp, &tx) + .map(|(table_update, metrics)| (FormatSwitch::Json(table_update), metrics))?, }) } @@ -114,7 +121,13 @@ impl ModuleSubscriptions { } }; - let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth)?; + let (table_rows, metrics) = self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth)?; + + record_exec_metrics( + &WorkloadType::Subscribe, + &self.relational_db.database_identity(), + metrics, + ); // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently. // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here @@ -177,7 +190,13 @@ impl ModuleSubscriptions { self.relational_db.release_tx(tx); }); let auth = AuthCtx::new(self.owner_identity, sender.id.identity); - let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth)?; + let (table_rows, metrics) = self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth)?; + + record_exec_metrics( + &WorkloadType::Subscribe, + &self.relational_db.database_identity(), + metrics, + ); WORKER_METRICS .subscription_queries @@ -263,11 +282,19 @@ impl ModuleSubscriptions { )?; let tx = DeltaTx::from(&*tx); - let database_update = match sender.config.protocol { - Protocol::Text => FormatSwitch::Json(execute_plans(plans, comp, &tx)?), - Protocol::Binary => FormatSwitch::Bsatn(execute_plans(plans, comp, &tx)?), + let (database_update, metrics) = match sender.config.protocol { + Protocol::Binary => execute_plans(plans, comp, &tx) + .map(|(table_update, metrics)| (FormatSwitch::Bsatn(table_update), metrics))?, + Protocol::Text => execute_plans(plans, comp, &tx) + .map(|(table_update, metrics)| (FormatSwitch::Json(table_update), metrics))?, }; + record_exec_metrics( + &WorkloadType::Subscribe, + &self.relational_db.database_identity(), + metrics, + ); + // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently. // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here // but that should not pose an issue. 
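The shape repeated in these hunks is: evaluation returns an `ExecutionMetrics` value alongside its results, and the caller flushes it into the global counters via `record_exec_metrics`. This keeps the evaluator free of global state. A minimal, self-contained sketch of the pattern follows (the fields mirror a subset of the real `ExecutionMetrics` counters, while `evaluate` and the string workload label are illustrative stand-ins):

```rust
#[derive(Default, Debug)]
struct ExecutionMetrics {
    index_seeks: usize,
    rows_scanned: usize,
    bytes_scanned: usize,
}

impl ExecutionMetrics {
    /// Fold another measurement into this one, as `tx.metrics.merge` does.
    fn merge(&mut self, other: ExecutionMetrics) {
        self.index_seeks += other.index_seeks;
        self.rows_scanned += other.rows_scanned;
        self.bytes_scanned += other.bytes_scanned;
    }
}

/// Evaluation reports what it did instead of touching global counters.
fn evaluate(rows: &[u64]) -> (Vec<u64>, ExecutionMetrics) {
    let out: Vec<u64> = rows.iter().copied().filter(|r| r % 2 == 0).collect();
    let metrics = ExecutionMetrics {
        index_seeks: 0,
        rows_scanned: rows.len(),
        bytes_scanned: rows.len() * std::mem::size_of::<u64>(),
    };
    (out, metrics)
}

/// The caller decides the workload label and records exactly once.
fn record_exec_metrics(workload: &str, metrics: &ExecutionMetrics) {
    println!("{workload}: {metrics:?}");
}

fn main() {
    let mut total = ExecutionMetrics::default();
    let (rows, metrics) = evaluate(&[1, 2, 3, 4]);
    total.merge(metrics);
    assert_eq!(rows, vec![2, 4]);
    record_exec_metrics("subscribe", &total);
}
```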
@@ -344,7 +371,9 @@ impl ModuleSubscriptions { let event = Arc::new(event); match &event.status { - EventStatus::Committed(_) => subscriptions.eval_updates(&read_tx, event.clone(), caller), + EventStatus::Committed(_) => { + subscriptions.eval_updates(&read_tx, event.clone(), caller, &self.relational_db.database_identity()) + } EventStatus::Failed(_) => { if let Some(client) = caller { let message = TransactionUpdateMessage { diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index c61f400416d..0dc94726727 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -3,15 +3,18 @@ use super::tx::DeltaTx; use crate::client::messages::{SubscriptionUpdateMessage, TransactionUpdateMessage}; use crate::client::{ClientConnectionSender, Protocol}; use crate::error::DBError; +use crate::execution_context::WorkloadType; use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue}; use crate::messages::websocket::{self as ws, TableUpdate}; use crate::subscription::delta::eval_delta; +use crate::subscription::record_exec_metrics; use hashbrown::hash_map::OccupiedError; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryId, QueryUpdate, WebsocketFormat, }; use spacetimedb_data_structures::map::{Entry, HashCollectionExt, HashMap, HashSet, IntMap}; +use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::{Address, Identity}; use spacetimedb_primitives::TableId; use spacetimedb_query::delta::DeltaPlan; @@ -315,7 +318,13 @@ impl SubscriptionManager { /// evaluates only the necessary queries for those delta tables, /// and then sends the results to each client. #[tracing::instrument(level = "trace", skip_all)] - pub fn eval_updates(&self, tx: &DeltaTx, event: Arc, caller: Option<&ClientConnectionSender>) { + pub fn eval_updates( + &self, + tx: &DeltaTx, + event: Arc, + caller: Option<&ClientConnectionSender>, + database_identity: &Identity, + ) { use FormatSwitch::{Bsatn, Json}; let tables = &event.status.database_update().unwrap().tables; @@ -323,7 +332,7 @@ impl SubscriptionManager { // Put the main work on a rayon compute thread. rayon::scope(|_| { let span = tracing::info_span!("eval_incr").entered(); - let mut eval = tables + let (updates, metrics) = tables .iter() .filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty()) .map(|DatabaseTableUpdate { table_id, .. }| table_id) @@ -335,7 +344,7 @@ impl SubscriptionManager { // If N clients are subscribed to a query, // we copy the DatabaseTableUpdate N times, // which involves cloning BSATN (binary) or product values (json). - .flat_map_iter(|(hash, plan)| { + .map(|(hash, plan)| { let table_id = plan.table_id(); let table_name = plan.table_name(); // Store at most one copy of the serialization to BSATN @@ -344,23 +353,38 @@ impl SubscriptionManager { // but we only fill `ops_bin` and `ops_json` at most once. // The former will be `Some(_)` if some subscriber uses `Protocol::Binary` // and the latter `Some(_)` if some subscriber uses `Protocol::Text`. 
- let mut ops_bin: Option<(CompressableQueryUpdate, _)> = None; - let mut ops_json: Option<(QueryUpdate, _)> = None; + let mut ops_bin: Option<(CompressableQueryUpdate, _, _)> = None; + let mut ops_json: Option<(QueryUpdate, _, _)> = None; fn memo_encode( updates: &UpdatesRelValue<'_>, client: &ClientConnectionSender, - memory: &mut Option<(F::QueryUpdate, u64)>, + memory: &mut Option<(F::QueryUpdate, u64, usize)>, + metrics: &mut ExecutionMetrics, ) -> (F::QueryUpdate, u64) { - memory - .get_or_insert_with(|| updates.encode::(client.config.compression)) - .clone() + let (update, num_rows, num_bytes) = memory + .get_or_insert_with(|| { + let encoded = updates.encode::(client.config.compression); + // The first time we insert into this map, we call encode. + // This is when we serialize the rows to BSATN/JSON. + // Hence this is where we increment `bytes_scanned`. + metrics.bytes_scanned += encoded.2; + encoded + }) + .clone(); + // We call this function for each query, + // and for each client subscribed to it. + // Therefore every time we call this function, + // we update the `bytes_sent_to_clients` metric. + metrics.bytes_sent_to_clients += num_bytes; + (update, num_rows) } let evaluator = plan.evaluator(tx); + let mut metrics = ExecutionMetrics::default(); // TODO: Handle errors instead of skipping them - eval_delta(tx, &evaluator) + let updates = eval_delta(tx, &mut metrics, &evaluator) .ok() .filter(|delta_updates| delta_updates.has_updates()) .map(|delta_updates| { @@ -368,23 +392,42 @@ impl SubscriptionManager { .get(hash) .into_iter() .flat_map(|query| query.all_clients()) - .map(move |id| { + .map(|id| { let client = &self.clients[id].outbound_ref; let update = match client.config.protocol { - Protocol::Binary => { - Bsatn(memo_encode::(&delta_updates, client, &mut ops_bin)) - } - Protocol::Text => { - Json(memo_encode::(&delta_updates, client, &mut ops_json)) - } + Protocol::Binary => Bsatn(memo_encode::( + &delta_updates, + client, + &mut ops_bin, + &mut metrics, + )), + Protocol::Text => Json(memo_encode::( + &delta_updates, + client, + &mut ops_json, + &mut metrics, + )), }; (id, table_id, table_name.clone(), update) }) .collect::>() }) - .unwrap_or_default() + .unwrap_or_default(); + + (updates, metrics) }) - .collect::>() + .reduce( + || (vec![], ExecutionMetrics::default()), + |(mut updates, mut aggregated_metrics), (table_upates, metrics)| { + updates.extend(table_upates); + aggregated_metrics.merge(metrics); + (updates, aggregated_metrics) + }, + ); + + record_exec_metrics(&WorkloadType::Update, database_identity, metrics); + + let mut eval = updates .into_iter() // For each subscriber, aggregate all the updates for the same table. // That is, we build a map `(subscriber_id, table_id) -> updates`. 
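The accounting rule `memo_encode` implements is worth spelling out: serialization cost (`bytes_scanned`) is charged once per query, since the encoded update is memoized, while `bytes_sent_to_clients` is charged once per subscriber that receives a copy. A small sketch of that rule in isolation (the `Metrics` struct and the string payload are illustrative, not the real types):

```rust
#[derive(Default, Debug)]
struct Metrics {
    bytes_scanned: usize,
    bytes_sent_to_clients: usize,
}

/// Encode `payload` at most once, but charge every caller for the send.
fn memo_encode(payload: &str, memo: &mut Option<Vec<u8>>, metrics: &mut Metrics) -> Vec<u8> {
    let bytes = memo
        .get_or_insert_with(|| {
            let encoded = payload.as_bytes().to_vec();
            // First (and only) serialization: count it as scanned bytes.
            metrics.bytes_scanned += encoded.len();
            encoded
        })
        .clone();
    // Every subscriber receives a copy, so count each send.
    metrics.bytes_sent_to_clients += bytes.len();
    bytes
}

fn main() {
    let mut memo = None;
    let mut metrics = Metrics::default();
    for _client in 0..3 {
        memo_encode("update", &mut memo, &mut metrics);
    }
    // Serialized once, sent three times.
    assert_eq!(metrics.bytes_scanned, 6);
    assert_eq!(metrics.bytes_sent_to_clients, 18);
    println!("{metrics:?}");
}
```

The memoization is the reason the serialization charge cannot simply live inside `encode`: with N subscribers on one query, the update is encoded once but delivered N times, and the two metrics must diverge accordingly.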
@@ -956,7 +999,7 @@ mod tests { }); db.with_read_only(Workload::Update, |tx| { - subscriptions.eval_updates(&(&*tx).into(), event, Some(&client0)) + subscriptions.eval_updates(&(&*tx).into(), event, Some(&client0), &db.database_identity()) }); tokio::runtime::Builder::new_current_thread() diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 2dd0ba71da8..cc291378660 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -132,6 +132,7 @@ mod tests { use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::error::ResultTest; use spacetimedb_lib::identity::AuthCtx; + use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::relation::FieldName; use spacetimedb_lib::Identity; use spacetimedb_primitives::{ColId, TableId}; @@ -733,6 +734,7 @@ mod tests { fn eval_incr( db: &RelationalDB, + metrics: &mut ExecutionMetrics, plan: &DeltaPlan, ops: Vec<(TableId, ProductValue, bool)>, ) -> ResultTest { @@ -751,7 +753,7 @@ mod tests { let table_name = plan.table_name(); let tx = DeltaTx::new(&tx, &data); let evaluator = plan.evaluator(&tx); - let updates = eval_delta(&tx, &evaluator).unwrap(); + let updates = eval_delta(&tx, metrics, &evaluator).unwrap(); let inserts = updates .inserts @@ -788,10 +790,17 @@ mod tests { let r1 = product!(10, 0, 2); let r2 = product!(10, 0, 3); - let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr(db, &mut metrics, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // No updates to report assert!(result.is_empty()); + + // The lhs row must always probe the rhs index. + // The rhs row passes the rhs filter, + // resulting in a probe of the rhs index. + assert_eq!(metrics.index_seeks, 2); Ok(()) } @@ -806,10 +815,17 @@ mod tests { let r1 = product!(13, 3, 5); let r2 = product!(13, 3, 6); - let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr(db, &mut metrics, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // No updates to report assert!(result.is_empty()); + + // The lhs row must always probe the rhs index. + // The rhs row doesn't pass the rhs filter, + // hence it doesn't survive to probe the lhs index. + assert_eq!(metrics.index_seeks, 0); Ok(()) } @@ -824,11 +840,17 @@ mod tests { let r1 = product!(10, 0, 2); let r2 = product!(10, 0, 5); - let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr(db, &mut metrics, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // A single delete from lhs assert_eq!(result.tables.len(), 1); assert_eq!(result.tables[0], delete_op(lhs_id, "lhs", product!(0, 5))); + + // One row passes the rhs filter, the other does not. + // This results in a single probe of the lhs index. 
+ assert_eq!(metrics.index_seeks, 1); Ok(()) } @@ -843,11 +865,17 @@ mod tests { let r1 = product!(13, 3, 5); let r2 = product!(13, 3, 4); - let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr(db, &mut metrics, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // A single insert into lhs assert_eq!(result.tables.len(), 1); assert_eq!(result.tables[0], insert_op(lhs_id, "lhs", product!(3, 8))); + + // One row passes the rhs filter, the other does not. + // This results in a single probe of the lhs index. + assert_eq!(metrics.index_seeks, 1); Ok(()) } @@ -862,11 +890,23 @@ mod tests { let lhs_row = product!(5, 10); let rhs_row = product!(20, 5, 3); - let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, true), (rhs_id, rhs_row, true)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr( + db, + &mut metrics, + &query, + vec![(lhs_id, lhs_row, true), (rhs_id, rhs_row, true)], + )?; // A single insert into lhs assert_eq!(result.tables.len(), 1); assert_eq!(result.tables[0], insert_op(lhs_id, "lhs", product!(5, 10))); + + // The lhs row must always probe the rhs index. + // The rhs row passes the rhs filter, + // resulting in a probe of the rhs index. + assert_eq!(metrics.index_seeks, 2); Ok(()) } @@ -881,10 +921,22 @@ mod tests { let lhs_row = product!(5, 10); let rhs_row = product!(20, 5, 5); - let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, true), (rhs_id, rhs_row, true)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr( + db, + &mut metrics, + &query, + vec![(lhs_id, lhs_row, true), (rhs_id, rhs_row, true)], + )?; // No updates to report assert_eq!(result.tables.len(), 0); + + // The lhs row must always probe the rhs index. + // The rhs row doesn't pass the rhs filter, + // hence it doesn't survive to probe the lhs index. + assert_eq!(metrics.index_seeks, 1); Ok(()) } @@ -899,11 +951,23 @@ mod tests { let lhs_row = product!(0, 5); let rhs_row = product!(10, 0, 2); - let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, false), (rhs_id, rhs_row, false)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr( + db, + &mut metrics, + &query, + vec![(lhs_id, lhs_row, false), (rhs_id, rhs_row, false)], + )?; // A single delete from lhs assert_eq!(result.tables.len(), 1); assert_eq!(result.tables[0], delete_op(lhs_id, "lhs", product!(0, 5))); + + // The lhs row must always probe the rhs index. + // The rhs row passes the rhs filter, + // resulting in a probe of the rhs index. + assert_eq!(metrics.index_seeks, 2); Ok(()) } @@ -918,10 +982,22 @@ mod tests { let lhs_row = product!(3, 8); let rhs_row = product!(13, 3, 5); - let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, false), (rhs_id, rhs_row, false)])?; + let mut metrics = ExecutionMetrics::default(); + + let result = eval_incr( + db, + &mut metrics, + &query, + vec![(lhs_id, lhs_row, false), (rhs_id, rhs_row, false)], + )?; // No updates to report assert_eq!(result.tables.len(), 0); + + // The lhs row must always probe the rhs index. + // The rhs row doesn't pass the rhs filter, + // hence it doesn't survive to probe the lhs index. 
+ assert_eq!(metrics.index_seeks, 1); Ok(()) } @@ -938,8 +1014,11 @@ mod tests { let rhs_old = product!(11, 1, 3); let rhs_new = product!(11, 1, 4); + let mut metrics = ExecutionMetrics::default(); + let result = eval_incr( db, + &mut metrics, &query, vec![ (lhs_id, lhs_old, false), @@ -963,6 +1042,11 @@ mod tests { inserts: [lhs_new].into(), }, ); + + // The lhs rows must always probe the rhs index. + // The rhs rows pass the rhs filter, + // resulting in probes of the rhs index. + assert_eq!(metrics.index_seeks, 4); Ok(()) } } diff --git a/crates/core/src/util/slow.rs b/crates/core/src/util/slow.rs index b9d56a58ca2..080348cc4b5 100644 --- a/crates/core/src/util/slow.rs +++ b/crates/core/src/util/slow.rs @@ -49,11 +49,12 @@ mod tests { use super::*; use crate::db::datastore::system_tables::ST_VARNAME_SLOW_QRY; - use crate::db::datastore::system_tables::{StVarName, StVarValue, ST_VARNAME_SLOW_INC, ST_VARNAME_SLOW_SUB}; + use crate::db::datastore::system_tables::{StVarName, ST_VARNAME_SLOW_INC, ST_VARNAME_SLOW_SUB}; use crate::sql::compiler::compile_sql; use crate::sql::execute::tests::execute_for_testing; use spacetimedb_lib::error::ResultTest; use spacetimedb_lib::identity::AuthCtx; + use spacetimedb_lib::st_var::StVarValue; use spacetimedb_lib::ProductValue; use crate::db::relational_db::tests_utils::{insert, TestDB}; diff --git a/crates/core/src/worker_metrics/mod.rs b/crates/core/src/worker_metrics/mod.rs index 8dbd3e5b1ae..666353fff06 100644 --- a/crates/core/src/worker_metrics/mod.rs +++ b/crates/core/src/worker_metrics/mod.rs @@ -85,6 +85,11 @@ metrics_group!( #[labels(caller_identity: Identity, module_hash: Hash, caller_address: Address, reducer_symbol: str)] pub wasm_instance_errors: IntCounterVec, + #[name = spacetime_worker_wasm_memory_bytes] + #[help = "The number of bytes of linear memory allocated by the database's WASM module instance"] + #[labels(database_identity: Identity)] + pub wasm_memory_bytes: IntGaugeVec, + #[name = spacetime_active_queries] #[help = "The number of active subscription queries"] #[labels(database_identity: Identity)] @@ -99,6 +104,11 @@ metrics_group!( #[help = "The time spent executing a reducer (in seconds), plus the time spent evaluating its subscription queries"] #[labels(db: Identity, reducer: str)] pub reducer_plus_query_duration: HistogramVec, + + #[name = spacetime_num_bytes_sent_to_clients_total] + #[help = "The cumulative number of bytes sent to clients"] + #[labels(txn_type: WorkloadType, db: Identity)] + pub bytes_sent_to_clients: IntCounterVec, } ); diff --git a/crates/execution/Cargo.toml b/crates/execution/Cargo.toml index 530ad8e4132..614ef80a23d 100644 --- a/crates/execution/Cargo.toml +++ b/crates/execution/Cargo.toml @@ -10,6 +10,7 @@ description = "The SpacetimeDB query engine" anyhow.workspace = true spacetimedb-expr.workspace = true spacetimedb-lib.workspace = true +spacetimedb-sats.workspace = true spacetimedb-physical-plan.workspace = true spacetimedb-primitives.workspace = true spacetimedb-sql-parser.workspace = true diff --git a/crates/execution/src/dml.rs b/crates/execution/src/dml.rs new file mode 100644 index 00000000000..17a2a942321 --- /dev/null +++ b/crates/execution/src/dml.rs @@ -0,0 +1,152 @@ +use anyhow::Result; +use spacetimedb_lib::{metrics::ExecutionMetrics, AlgebraicValue, ProductValue}; +use spacetimedb_physical_plan::dml::{DeletePlan, InsertPlan, MutationPlan, UpdatePlan}; +use spacetimedb_primitives::{ColId, TableId}; +use spacetimedb_sats::size_of::SizeOf; + +use crate::{pipelined::PipelinedProject, 
Datastore, DeltaStore}; + +/// A mutable datastore can read as well as insert and delete rows +pub trait MutDatastore: Datastore + DeltaStore { + fn insert_product_value(&mut self, table_id: TableId, row: &ProductValue) -> Result<()>; + fn delete_product_value(&mut self, table_id: TableId, row: &ProductValue) -> Result<()>; +} + +/// Executes a physical mutation plan +pub enum MutExecutor { + Insert(InsertExecutor), + Delete(DeleteExecutor), + Update(UpdateExecutor), +} + +impl From<MutationPlan> for MutExecutor { + fn from(plan: MutationPlan) -> Self { + match plan { + MutationPlan::Insert(plan) => Self::Insert(plan.into()), + MutationPlan::Delete(plan) => Self::Delete(plan.into()), + MutationPlan::Update(plan) => Self::Update(plan.into()), + } + } +} + +impl MutExecutor { + pub fn execute<Tx: MutDatastore>(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + match self { + Self::Insert(exec) => exec.execute(tx, metrics), + Self::Delete(exec) => exec.execute(tx, metrics), + Self::Update(exec) => exec.execute(tx, metrics), + } + } +} + +/// Executes row insertions +pub struct InsertExecutor { + table_id: TableId, + rows: Vec<ProductValue>, +} + +impl From<InsertPlan> for InsertExecutor { + fn from(plan: InsertPlan) -> Self { + Self { + rows: plan.rows, + table_id: plan.table.table_id, + } + } +} + +impl InsertExecutor { + fn execute<Tx: MutDatastore>(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + for row in &self.rows { + tx.insert_product_value(self.table_id, row)?; + } + // TODO: It would be better to get this metric from the bsatn buffer. + // But we haven't been concerned with optimizing DML up to this point. + metrics.bytes_written += self.rows.iter().map(|row| row.size_of()).sum::<usize>(); + Ok(()) + } +} + +/// Executes row deletions +pub struct DeleteExecutor { + table_id: TableId, + filter: PipelinedProject, +} + +impl From<DeletePlan> for DeleteExecutor { + fn from(plan: DeletePlan) -> Self { + Self { + table_id: plan.table.table_id, + filter: plan.filter.into(), + } + } +} + +impl DeleteExecutor { + fn execute<Tx: MutDatastore>(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + // TODO: Delete by row id instead of product value + let mut deletes = vec![]; + self.filter.execute(tx, metrics, &mut |row| { + deletes.push(row.to_product_value()); + Ok(()) + })?; + // TODO: This metric should be updated inline when we serialize. + // Note that we don't update bytes written, + // because deletes just update a free list in the datastore; + // they don't actually write out any bytes. + metrics.bytes_scanned += deletes.iter().map(|row| row.size_of()).sum::<usize>(); + for row in &deletes { + tx.delete_product_value(self.table_id, row)?; + } + Ok(()) + } +} + +/// Executes row updates +pub struct UpdateExecutor { + table_id: TableId, + columns: Vec<(ColId, AlgebraicValue)>, + filter: PipelinedProject, +} + +impl From<UpdatePlan> for UpdateExecutor { + fn from(plan: UpdatePlan) -> Self { + Self { + columns: plan.columns, + table_id: plan.table.table_id, + filter: plan.filter.into(), + } + } +} + +impl UpdateExecutor { + fn execute<Tx: MutDatastore>(&self, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + let mut deletes = vec![]; + self.filter.execute(tx, metrics, &mut |row| { + deletes.push(row.to_product_value()); + Ok(()) + })?; + for row in &deletes { + tx.delete_product_value(self.table_id, row)?; + } + // TODO: This metric should be updated inline when we serialize.
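+ // As in DeleteExecutor above, rows matched by the filter are charged to `bytes_scanned`; + // only the re-inserted rows below are charged to `bytes_written`.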
+ metrics.bytes_scanned += deletes.iter().map(|row| row.size_of()).sum::<usize>(); + for row in &deletes { + let row = ProductValue::from_iter( + row + // Update the deleted rows with the new field values + .into_iter() + .cloned() + .enumerate() + .map(|(i, elem)| { + self.columns + .iter() + .find(|(col_id, _)| i == col_id.idx()) + .map(|(_, value)| value.clone()) + .unwrap_or_else(|| elem) + }), + ); + tx.insert_product_value(self.table_id, &row)?; + metrics.bytes_written += row.size_of(); + } + Ok(()) + } +} diff --git a/crates/execution/src/lib.rs b/crates/execution/src/lib.rs index 474892557b3..9d636547c95 100644 --- a/crates/execution/src/lib.rs +++ b/crates/execution/src/lib.rs @@ -16,6 +16,7 @@ use spacetimedb_table::{ table::{IndexScanIter, RowRef, Table, TableScanIter}, }; +pub mod dml; pub mod iter; pub mod pipelined; @@ -79,6 +80,15 @@ pub enum Row<'a> { Ref(&'a ProductValue), } +impl Row<'_> { + pub fn to_product_value(&self) -> ProductValue { + match self { + Self::Ptr(ptr) => ptr.to_product_value(), + Self::Ref(val) => (*val).clone(), + } + } +} + impl_serialize!(['a] Row<'a>, (self, ser) => match self { Self::Ptr(row) => row.serialize(ser), Self::Ref(row) => row.serialize(ser), diff --git a/crates/execution/src/pipelined.rs b/crates/execution/src/pipelined.rs index 4c362242b53..d8491d2c347 100644 --- a/crates/execution/src/pipelined.rs +++ b/crates/execution/src/pipelined.rs @@ -4,14 +4,63 @@ use std::{ }; use anyhow::{anyhow, Result}; -use spacetimedb_lib::{query::Delta, AlgebraicValue, ProductValue}; +use spacetimedb_lib::{metrics::ExecutionMetrics, query::Delta, sats::size_of::SizeOf, AlgebraicValue, ProductValue}; use spacetimedb_physical_plan::plan::{ - HashJoin, IxJoin, IxScan, PhysicalExpr, PhysicalPlan, ProjectField, ProjectPlan, Sarg, Semi, TupleField, + HashJoin, IxJoin, IxScan, PhysicalExpr, PhysicalPlan, ProjectField, ProjectListPlan, ProjectPlan, Sarg, Semi, + TupleField, }; use spacetimedb_primitives::{ColId, IndexId, TableId}; use crate::{Datastore, DeltaStore, Row, Tuple}; + +/// An executor for explicit column projections. +/// Note: this plan can only be constructed from the http api, +/// which is not considered performance critical. +/// Hence this operator is not particularly optimized.
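Before the query-side executors that follow, a note on how the DML executors above are driven: a `MutationPlan` is lowered into a `MutExecutor` via `From`, then executed against any `MutDatastore`. A minimal sketch of that flow; the crate paths and the caller-supplied `tx` (with its `MutDatastore` impl) are assumptions, not part of this diff:

```rust
use anyhow::Result;
use spacetimedb_execution::dml::{MutDatastore, MutExecutor};
use spacetimedb_lib::metrics::ExecutionMetrics;
use spacetimedb_physical_plan::dml::MutationPlan;

/// Lower an already-optimized mutation plan (see `MutationPlan::optimize`
/// later in the diff) and run it, returning the metrics it accumulated.
fn run_dml<Tx: MutDatastore>(plan: MutationPlan, tx: &mut Tx) -> Result<ExecutionMetrics> {
    let exec = MutExecutor::from(plan);
    let mut metrics = ExecutionMetrics::default();
    exec.execute(tx, &mut metrics)?;
    Ok(metrics)
}
```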
+pub enum ProjectListExecutor { + Name(PipelinedProject), + List(PipelinedExecutor, Vec), +} + +impl From for ProjectListExecutor { + fn from(plan: ProjectListPlan) -> Self { + match plan { + ProjectListPlan::Name(plan) => Self::Name(plan.into()), + ProjectListPlan::List(plan, fields) => Self::List(plan.into(), fields), + } + } +} + +impl ProjectListExecutor { + pub fn execute( + &self, + tx: &Tx, + metrics: &mut ExecutionMetrics, + f: &mut dyn FnMut(ProductValue) -> Result<()>, + ) -> Result<()> { + let mut n = 0; + let mut bytes_scanned = 0; + let mut f = |row: ProductValue| { + n += 1; + bytes_scanned += row.size_of(); + f(row) + }; + match self { + Self::Name(plan) => { + plan.execute(tx, metrics, &mut |row| f(row.to_product_value()))?; + } + Self::List(plan, fields) => { + plan.execute(tx, metrics, &mut |t| { + f(ProductValue::from_iter(fields.iter().map(|field| t.project(field)))) + })?; + } + } + metrics.rows_scanned += n; + metrics.bytes_scanned += bytes_scanned; + Ok(()) + } +} + /// Implements a projection on top of a pipelined executor pub enum PipelinedProject { None(PipelinedExecutor), @@ -32,31 +81,37 @@ impl PipelinedProject { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Row<'a>) -> Result<()>, ) -> Result<()> { + let mut n = 0; match self { Self::None(plan) => { // No explicit projection. // This means the input does not return tuples. // It returns either row ids or product values. - plan.execute(tx, &mut |t| { + plan.execute(tx, metrics, &mut |t| { + n += 1; if let Tuple::Row(row) = t { f(row)?; } Ok(()) - }) + })?; } Self::Some(plan, i) => { // The contrary is true for explicit projections. // They return a tuple of row ids or product values. - plan.execute(tx, &mut |t| { + plan.execute(tx, metrics, &mut |t| { + n += 1; if let Some(row) = t.select(*i) { f(row)?; } Ok(()) - }) + })?; } } + metrics.rows_scanned += n; + Ok(()) } } @@ -134,15 +189,16 @@ impl PipelinedExecutor { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { match self { - Self::TableScan(scan) => scan.execute(tx, f), - Self::IxScan(scan) => scan.execute(tx, f), - Self::IxJoin(join) => join.execute(tx, f), - Self::HashJoin(join) => join.execute(tx, f), - Self::NLJoin(join) => join.execute(tx, f), - Self::Filter(filter) => filter.execute(tx, f), + Self::TableScan(scan) => scan.execute(tx, metrics, f), + Self::IxScan(scan) => scan.execute(tx, metrics, f), + Self::IxJoin(join) => join.execute(tx, metrics, f), + Self::HashJoin(join) => join.execute(tx, metrics, f), + Self::NLJoin(join) => join.execute(tx, metrics, f), + Self::Filter(filter) => filter.execute(tx, metrics, f), } } } @@ -157,8 +213,14 @@ impl PipelinedScan { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { + let mut n = 0; + let mut f = |t| { + n += 1; + f(t) + }; match self.delta { None => { for tuple in tx @@ -191,6 +253,7 @@ impl PipelinedScan { } } } + metrics.rows_scanned += n; Ok(()) } } @@ -244,8 +307,14 @@ impl PipelinedIxScan { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { + let mut n = 0; + let mut f = |t| { + n += 1; + f(t) + }; match self.prefix.as_slice() { [] => { for ptr in tx @@ -289,6 +358,8 @@ impl PipelinedIxScan { } 
} } + metrics.index_seeks += 1; + metrics.rows_scanned += n; Ok(()) } } @@ -315,6 +386,7 @@ impl PipelinedIxJoin { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { let blob_store = tx.blob_store(); @@ -323,6 +395,10 @@ impl PipelinedIxJoin { .get_index_by_id(self.rhs_index) .ok_or_else(|| anyhow!("IndexId `{0}` does not exist", self.rhs_index))?; + let mut n = 0; + let mut index_seeks = 0; + let mut bytes_scanned = 0; + match self { Self { lhs, @@ -333,12 +409,14 @@ impl PipelinedIxJoin { } => { // Should we evaluate the lhs tuple? // Probe the index to see if there is a matching row. - lhs.execute(tx, &mut |u| { - if rhs_index.contains_any(&u.project(lhs_field)) { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + index_seeks += 1; + if rhs_index.contains_any(&project(&u, lhs_field, &mut bytes_scanned)) { f(u)?; } Ok(()) - }) + })?; } Self { lhs, @@ -348,9 +426,11 @@ impl PipelinedIxJoin { .. } => { // Probe the index and evaluate the matching rhs row - lhs.execute(tx, &mut |u| { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + index_seeks += 1; if let Some(v) = rhs_index - .seek(&u.project(lhs_field)) + .seek(&project(&u, lhs_field, &mut bytes_scanned)) .next() .and_then(|ptr| rhs_table.get_row_ref(blob_store, ptr)) .map(Row::Ptr) @@ -359,7 +439,7 @@ impl PipelinedIxJoin { f(v)?; } Ok(()) - }) + })?; } Self { lhs, @@ -369,9 +449,11 @@ impl PipelinedIxJoin { .. } => { // Probe the index and evaluate the matching rhs row - lhs.execute(tx, &mut |u| { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + index_seeks += 1; if let Some(v) = rhs_index - .seek(&u.project(lhs_field)) + .seek(&project(&u, lhs_field, &mut bytes_scanned)) .next() .and_then(|ptr| rhs_table.get_row_ref(blob_store, ptr)) .map(Row::Ptr) @@ -380,7 +462,7 @@ impl PipelinedIxJoin { f(u.join(v))?; } Ok(()) - }) + })?; } Self { lhs, @@ -391,14 +473,16 @@ impl PipelinedIxJoin { } => { // How many times should we evaluate the lhs tuple? // Probe the index for the number of matching rows. - lhs.execute(tx, &mut |u| { - if let Some(n) = rhs_index.count(&u.project(lhs_field)) { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + index_seeks += 1; + if let Some(n) = rhs_index.count(&project(&u, lhs_field, &mut bytes_scanned)) { for _ in 0..n { f(u.clone())?; } } Ok(()) - }) + })?; } Self { lhs, @@ -408,9 +492,11 @@ impl PipelinedIxJoin { .. } => { // Probe the index and evaluate the matching rhs rows - lhs.execute(tx, &mut |u| { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + index_seeks += 1; for v in rhs_index - .seek(&u.project(lhs_field)) + .seek(&project(&u, lhs_field, &mut bytes_scanned)) .filter_map(|ptr| rhs_table.get_row_ref(blob_store, ptr)) .map(Row::Ptr) .map(Tuple::Row) @@ -418,7 +504,7 @@ impl PipelinedIxJoin { f(v)?; } Ok(()) - }) + })?; } Self { lhs, @@ -428,9 +514,11 @@ impl PipelinedIxJoin { .. 
} => { // Probe the index and evaluate the matching rhs rows - lhs.execute(tx, &mut |u| { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + index_seeks += 1; for v in rhs_index - .seek(&u.project(lhs_field)) + .seek(&project(&u, lhs_field, &mut bytes_scanned)) .filter_map(|ptr| rhs_table.get_row_ref(blob_store, ptr)) .map(Row::Ptr) .map(Tuple::Row) @@ -438,9 +526,13 @@ impl PipelinedIxJoin { f(u.clone().join(v.clone()))?; } Ok(()) - }) + })?; } } + metrics.index_seeks += index_seeks; + metrics.rows_scanned += n; + metrics.bytes_scanned += bytes_scanned; + Ok(()) } } @@ -460,8 +552,11 @@ impl BlockingHashJoin { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { + let mut n = 0; + let mut bytes_scanned = 0; match self { Self { lhs, @@ -472,16 +567,21 @@ impl BlockingHashJoin { semijoin: Semi::Lhs, } => { let mut rhs_table = HashSet::new(); - rhs.execute(tx, &mut |v| { - rhs_table.insert(v.project(rhs_field)); + rhs.execute(tx, metrics, &mut |v| { + rhs_table.insert(project(&v, rhs_field, &mut bytes_scanned)); Ok(()) })?; - lhs.execute(tx, &mut |u| { - if rhs_table.contains(&u.project(lhs_field)) { + + // How many rows did we pull from the rhs? + n += rhs_table.len(); + + lhs.execute(tx, metrics, &mut |u| { + n += 1; + if rhs_table.contains(&project(&u, lhs_field, &mut bytes_scanned)) { f(u)?; } Ok(()) - }) + })?; } Self { lhs, @@ -492,16 +592,21 @@ impl BlockingHashJoin { semijoin: Semi::Rhs, } => { let mut rhs_table = HashMap::new(); - rhs.execute(tx, &mut |v| { - rhs_table.insert(v.project(rhs_field), v); + rhs.execute(tx, metrics, &mut |v| { + rhs_table.insert(project(&v, rhs_field, &mut bytes_scanned), v); Ok(()) })?; - lhs.execute(tx, &mut |u| { - if let Some(v) = rhs_table.get(&u.project(lhs_field)) { + + // How many rows did we pull from the rhs? + n += rhs_table.len(); + + lhs.execute(tx, metrics, &mut |u| { + n += 1; + if let Some(v) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)) { f(v.clone())?; } Ok(()) - }) + })?; } Self { lhs, @@ -512,16 +617,21 @@ impl BlockingHashJoin { semijoin: Semi::All, } => { let mut rhs_table = HashMap::new(); - rhs.execute(tx, &mut |v| { - rhs_table.insert(v.project(rhs_field), v); + rhs.execute(tx, metrics, &mut |v| { + rhs_table.insert(project(&v, rhs_field, &mut bytes_scanned), v); Ok(()) })?; - lhs.execute(tx, &mut |u| { - if let Some(v) = rhs_table.get(&u.project(lhs_field)) { + + // How many rows did we pull from the rhs? 
+ n += rhs_table.len(); + + lhs.execute(tx, metrics, &mut |u| { + n += 1; + if let Some(v) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)) { f(u.clone().join(v.clone()))?; } Ok(()) - }) + })?; } Self { lhs, @@ -532,21 +642,23 @@ impl BlockingHashJoin { semijoin: Semi::Lhs, } => { let mut rhs_table = HashMap::new(); - rhs.execute(tx, &mut |v| { + rhs.execute(tx, metrics, &mut |v| { + n += 1; rhs_table - .entry(v.project(rhs_field)) + .entry(project(&v, rhs_field, &mut bytes_scanned)) .and_modify(|n| *n += 1) .or_insert(1); Ok(()) })?; - lhs.execute(tx, &mut |u| { - if let Some(n) = rhs_table.get(&u.project(lhs_field)).copied() { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + if let Some(n) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)).copied() { for _ in 0..n { f(u.clone())?; } } Ok(()) - }) + })?; } Self { lhs, @@ -557,8 +669,9 @@ impl BlockingHashJoin { semijoin: Semi::Rhs, } => { let mut rhs_table: HashMap> = HashMap::new(); - rhs.execute(tx, &mut |v| { - let key = v.project(rhs_field); + rhs.execute(tx, metrics, &mut |v| { + n += 1; + let key = project(&v, rhs_field, &mut bytes_scanned); if let Some(tuples) = rhs_table.get_mut(&key) { tuples.push(v); } else { @@ -566,14 +679,15 @@ impl BlockingHashJoin { } Ok(()) })?; - lhs.execute(tx, &mut |u| { - if let Some(rhs_tuples) = rhs_table.get(&u.project(lhs_field)) { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + if let Some(rhs_tuples) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)) { for v in rhs_tuples { f(v.clone())?; } } Ok(()) - }) + })?; } Self { lhs, @@ -584,8 +698,9 @@ impl BlockingHashJoin { semijoin: Semi::All, } => { let mut rhs_table: HashMap> = HashMap::new(); - rhs.execute(tx, &mut |v| { - let key = v.project(rhs_field); + rhs.execute(tx, metrics, &mut |v| { + n += 1; + let key = project(&v, rhs_field, &mut bytes_scanned); if let Some(tuples) = rhs_table.get_mut(&key) { tuples.push(v); } else { @@ -593,16 +708,20 @@ impl BlockingHashJoin { } Ok(()) })?; - lhs.execute(tx, &mut |u| { - if let Some(rhs_tuples) = rhs_table.get(&u.project(lhs_field)) { + lhs.execute(tx, metrics, &mut |u| { + n += 1; + if let Some(rhs_tuples) = rhs_table.get(&project(&u, lhs_field, &mut bytes_scanned)) { for v in rhs_tuples { f(u.clone().join(v.clone()))?; } } Ok(()) - }) + })?; } } + metrics.rows_scanned += n; + metrics.bytes_scanned += bytes_scanned; + Ok(()) } } @@ -618,19 +737,28 @@ impl BlockingNLJoin { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { let mut rhs = vec![]; - self.rhs.execute(tx, &mut |t| { - rhs.push(t); + self.rhs.execute(tx, metrics, &mut |v| { + rhs.push(v); Ok(()) })?; - self.lhs.execute(tx, &mut |u| { + + // How many rows did we pull from the rhs? 
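+ // Each buffered rhs row is counted once toward `rows_scanned` here, even though + // the loop below replays the whole rhs buffer for every lhs row.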
+ let mut n = rhs.len(); + + self.lhs.execute(tx, metrics, &mut |u| { + n += 1; for v in rhs.iter() { f(u.clone().join(v.clone()))?; } Ok(()) - }) + })?; + + metrics.rows_scanned += n; + Ok(()) } } @@ -644,13 +772,27 @@ impl PipelinedFilter { pub fn execute<'a, Tx: Datastore + DeltaStore>( &self, tx: &'a Tx, + metrics: &mut ExecutionMetrics, f: &mut dyn FnMut(Tuple<'a>) -> Result<()>, ) -> Result<()> { - self.input.execute(tx, &mut |t| { - if self.expr.eval_bool(&t) { + let mut n = 0; + let mut bytes_scanned = 0; + self.input.execute(tx, metrics, &mut |t| { + n += 1; + if self.expr.eval_bool_with_metrics(&t, &mut bytes_scanned) { f(t)?; } Ok(()) - }) + })?; + metrics.rows_scanned += n; + metrics.bytes_scanned += bytes_scanned; + Ok(()) } } + +/// A wrapper around [ProjectField] that increments a counter by the size of the projected value +fn project(row: &impl ProjectField, field: &TupleField, bytes_scanned: &mut usize) -> AlgebraicValue { + let value = row.project(field); + *bytes_scanned += value.size_of(); + value +} diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index f8c6455e1ee..5578650c7e4 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -26,11 +26,32 @@ pub enum ProjectName { } impl ProjectName { - /// What is the [TableId] for this projection? - pub fn table_id(&self) -> Option<TableId> { + /// The [TableSchema] of the returned rows. + /// Note this expression returns rows from a relvar. + /// Hence this method should never return [None]. + pub fn return_table(&self) -> Option<&TableSchema> { match self { - Self::None(input) => input.table_id(None), - Self::Some(input, var) => input.table_id(Some(var.as_ref())), + Self::None(input) => input.return_table(), + Self::Some(input, alias) => input.find_table_schema(alias), + } + } + + /// The [TableId] of the returned rows. + /// Note this expression returns rows from a relvar. + /// Hence this method should never return [None]. + pub fn return_table_id(&self) -> Option<TableId> { + match self { + Self::None(input) => input.return_table_id(), + Self::Some(input, alias) => input.find_table_id(alias), + } + } + + /// Iterate over the returned column names and types + pub fn for_each_return_field(&self, mut f: impl FnMut(&str, &AlgebraicType)) { + if let Some(schema) = self.return_table() { + for schema in schema.columns() { + f(&schema.col_name, &schema.col_type); + } } } } @@ -56,11 +77,37 @@ pub enum ProjectList { } impl ProjectList { - /// What is the [TableId] for this projection? - pub fn table_id(&self) -> Option<TableId> { + /// Does this expression project a single relvar? + /// If so, we return its [TableSchema]. + /// If not, it projects a list of columns, so we return [None]. + pub fn return_table(&self) -> Option<&TableSchema> { match self { + Self::Name(project) => project.return_table(), Self::List(..) => None, - Self::Name(proj) => proj.table_id(), + } + } + + /// Does this expression project a single relvar? + /// If so, we return its [TableId]. + /// If not, it projects a list of columns, so we return [None]. + pub fn return_table_id(&self) -> Option<TableId> { + match self { + Self::Name(project) => project.return_table_id(), + Self::List(..) => None, + } + } + + /// Iterate over the projected column names and types + pub fn for_each_return_field(&self, mut f: impl FnMut(&str, &AlgebraicType)) { + match self { + Self::Name(project) => { + project.for_each_return_field(f); + } + Self::List(_, fields) => { + for (name, FieldProject { ty, ..
}) in fields { + f(name, ty); + } + } } } } @@ -108,23 +155,39 @@ impl RelExpr { } } - /// What is the [TableId] for this expression or relvar? - pub fn table_id(&self, var: Option<&str>) -> Option<TableId> { - match (self, var) { - (Self::RelVar(Relvar { schema, .. }), None) => Some(schema.table_id), - (Self::RelVar(Relvar { schema, alias, .. }), Some(var)) if alias.as_ref() == var => Some(schema.table_id), - (Self::RelVar(Relvar { schema, .. }), Some(_)) => Some(schema.table_id), - (Self::Select(input, _), _) => input.table_id(var), - (Self::LeftDeepJoin(..) | Self::EqJoin(..), None) => None, - (Self::LeftDeepJoin(join) | Self::EqJoin(join, ..), Some(name)) => { - if join.rhs.alias.as_ref() == name { - Some(join.rhs.schema.table_id) - } else { - join.lhs.table_id(var) - } - } + /// Return the [TableSchema] for a relvar in the expression + pub fn find_table_schema(&self, alias: &str) -> Option<&TableSchema> { + match self { + Self::RelVar(relvar) if relvar.alias.as_ref() == alias => Some(&relvar.schema), + Self::Select(input, _) => input.find_table_schema(alias), + Self::EqJoin(LeftDeepJoin { rhs, .. }, ..) if rhs.alias.as_ref() == alias => Some(&rhs.schema), + Self::EqJoin(LeftDeepJoin { lhs, .. }, ..) => lhs.find_table_schema(alias), + Self::LeftDeepJoin(LeftDeepJoin { rhs, .. }) if rhs.alias.as_ref() == alias => Some(&rhs.schema), + Self::LeftDeepJoin(LeftDeepJoin { lhs, .. }) => lhs.find_table_schema(alias), + _ => None, } } + + /// Return the [TableId] for a relvar in the expression + pub fn find_table_id(&self, alias: &str) -> Option<TableId> { + self.find_table_schema(alias).map(|schema| schema.table_id) + } + + /// Does this expression return a single relvar? + /// If so, return its [TableSchema], otherwise return [None]. + pub fn return_table(&self) -> Option<&TableSchema> { + match self { + Self::RelVar(Relvar { schema, .. }) => Some(schema), + Self::Select(input, _) => input.return_table(), + _ => None, + } + } + + /// Does this expression return a single relvar? + /// If so, return its [TableId], otherwise return [None]. + pub fn return_table_id(&self) -> Option<TableId> { + self.return_table().map(|schema| schema.table_id) + } } /// A left deep binary cross product diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index a3752354890..a2d7712943c 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, ops::Deref}; use crate::statement::Statement; use check::{Relvars, TypingResult}; @@ -61,7 +61,7 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra Ok(Expr::Value(parse(v.into_string(), ty)?, ty.clone())) } (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => { - let table_type = vars.get(&table).ok_or_else(|| Unresolved::var(&table))?; + let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?; let ColumnSchema { col_pos, col_type, .. } = table_type .get_column_by_name(&field) .ok_or_else(|| Unresolved::var(&field))?; @@ -72,8 +72,9 @@ pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&Algebra })) } (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => { - let table_type = vars.get(&table).ok_or_else(|| Unresolved::var(&table))?; + let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?; let ColumnSchema { col_pos, col_type, ..
} = table_type + .as_ref() .get_column_by_name(&field) .ok_or_else(|| Unresolved::var(&field))?; if col_type != ty { diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 9293a2015f6..260ae9a94d1 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -1,18 +1,22 @@ use std::sync::Arc; -use spacetimedb_lib::{AlgebraicType, AlgebraicValue}; -use spacetimedb_primitives::ColId; +use spacetimedb_lib::{st_var::StVarValue, AlgebraicType, AlgebraicValue, ProductValue}; +use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_schema::schema::{ColumnSchema, TableSchema}; use spacetimedb_sql_parser::{ ast::{ sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate}, - SqlIdent, SqlLiteral, + BinOp, SqlIdent, SqlLiteral, }, parser::sql::parse_sql, }; use thiserror::Error; -use crate::{check::Relvars, expr::ProjectList}; +use crate::{ + check::Relvars, + errors::InvalidLiteral, + expr::{FieldProject, ProjectList, RelExpr, Relvar}, +}; use super::{ check::{SchemaView, TypeChecker, TypingResult}, @@ -23,29 +27,49 @@ use super::{ pub enum Statement { Select(ProjectList), + DML(DML), +} + +pub enum DML { Insert(TableInsert), Update(TableUpdate), Delete(TableDelete), - Set(SetVar), - Show(ShowVar), } -/// A resolved row of literal values for an insert -pub type Row = Box<[AlgebraicValue]>; +impl DML { + /// Returns the schema of the table on which this mutation applies + pub fn table_schema(&self) -> &TableSchema { + match self { + Self::Insert(insert) => &insert.table, + Self::Delete(delete) => &delete.table, + Self::Update(update) => &update.table, + } + } + + /// Returns the id of the table on which this mutation applies + pub fn table_id(&self) -> TableId { + self.table_schema().table_id + } + + /// Returns the name of the table on which this mutation applies + pub fn table_name(&self) -> Box { + self.table_schema().table_name.clone() + } +} pub struct TableInsert { - pub into: Arc, - pub rows: Box<[Row]>, + pub table: Arc, + pub rows: Box<[ProductValue]>, } pub struct TableDelete { - pub from: Arc, - pub expr: Option, + pub table: Arc, + pub filter: Option, } pub struct TableUpdate { - pub schema: Arc, - pub values: Box<[(ColId, AlgebraicValue)]>, + pub table: Arc, + pub columns: Box<[(ColId, AlgebraicValue)]>, pub filter: Option, } @@ -92,10 +116,13 @@ pub fn type_insert(insert: SqlInsert, tx: &impl SchemaView) -> TypingResult { values.push(AlgebraicValue::Bool(v)); @@ -114,11 +141,11 @@ pub fn type_insert(insert: SqlInsert, tx: &impl SchemaView) -> TypingResult TypingResult TypingResult TypingResult bool { var == VAR_ROW_LIMIT || var == VAR_SLOW_QUERY || var == VAR_SLOW_UPDATE || var == VAR_SLOW_SUB } -pub fn type_set(set: SqlSet) -> TypingResult { - let SqlSet(SqlIdent(name), lit) = set; - if !is_var_valid(&name) { +const ST_VAR_NAME: &str = "st_var"; +const VALUE_COLUMN: &str = "value"; + +/// The concept of `SET` only exists in the ast. +/// We translate it here to an `INSERT` on the `st_var` system table. +/// That is: +/// +/// ```sql +/// SET var TO ... +/// ``` +/// +/// is rewritten as +/// +/// ```sql +/// INSERT INTO st_var (name, value) VALUES ('var', ...) 
+/// ``` +pub fn type_and_rewrite_set(set: SqlSet, tx: &impl SchemaView) -> TypingResult { + let SqlSet(SqlIdent(var_name), lit) = set; + if !is_var_valid(&var_name) { return Err(InvalidVar { - name: name.into_string(), + name: var_name.into_string(), } .into()); } + match lit { SqlLiteral::Bool(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::Bool).into()), SqlLiteral::Str(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::String).into()), SqlLiteral::Hex(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::bytes()).into()), - SqlLiteral::Num(n) => Ok(SetVar { - name: name.into_string(), - value: parse(n.into_string(), &AlgebraicType::U64)?, - }), + SqlLiteral::Num(n) => { + let table = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?; + let var_name = AlgebraicValue::String(var_name); + let sum_value = StVarValue::try_from_primitive(parse(n.clone().into_string(), &AlgebraicType::U64)?) + .map_err(|_| InvalidLiteral::new(n.into_string(), &AlgebraicType::U64))? + .into(); + Ok(TableInsert { + table, + rows: Box::new([ProductValue::from_iter([var_name, sum_value])]), + }) + } } } -pub fn type_show(show: SqlShow) -> TypingResult { - let SqlShow(SqlIdent(name)) = show; - if !is_var_valid(&name) { +/// The concept of `SHOW` only exists in the ast. +/// We translate it here to a `SELECT` on the `st_var` system table. +/// That is: +/// +/// ```sql +/// SHOW var +/// ``` +/// +/// is rewritten as +/// +/// ```sql +/// SELECT value FROM st_var WHERE name = 'var' +/// ``` +pub fn type_and_rewrite_show(show: SqlShow, tx: &impl SchemaView) -> TypingResult { + let SqlShow(SqlIdent(var_name)) = show; + if !is_var_valid(&var_name) { return Err(InvalidVar { - name: name.into_string(), + name: var_name.into_string(), } .into()); } - Ok(ShowVar { - name: name.into_string(), - }) + + let table_schema = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?; + + let value_col_ty = table_schema + .as_ref() + .get_column(1) + .map(|ColumnSchema { col_type, .. }| col_type) + .ok_or_else(|| Unresolved::field(ST_VAR_NAME, VALUE_COLUMN))?; + + // ------------------------------------------- + // SELECT value FROM st_var WHERE name = 'var' + // ^^^^ + // ------------------------------------------- + let var_name_field = Expr::Field(FieldProject { + table: ST_VAR_NAME.into(), + // TODO: Avoid hard coding the field position. + // See `StVarFields` for the schema of `st_var`. + field: 0, + ty: AlgebraicType::String, + }); + + // ------------------------------------------- + // SELECT value FROM st_var WHERE name = 'var' + // ^^^ + // ------------------------------------------- + let var_name_value = Expr::Value(AlgebraicValue::String(var_name), AlgebraicType::String); + + // ------------------------------------------- + // SELECT value FROM st_var WHERE name = 'var' + // ^^^^^ + // ------------------------------------------- + let column_list = vec![( + VALUE_COLUMN.into(), + FieldProject { + table: ST_VAR_NAME.into(), + // TODO: Avoid hard coding the field position. + // See `StVarFields` for the schema of `st_var`. 
+ field: 1, + ty: value_col_ty.clone(), + }, + )]; + + // ------------------------------------------- + // SELECT value FROM st_var WHERE name = 'var' + // ^^^^^^ + // ------------------------------------------- + let relvar = RelExpr::RelVar(Relvar { + schema: table_schema, + alias: ST_VAR_NAME.into(), + delta: None, + }); + + let filter = Expr::BinOp( + // ------------------------------------------- + // SELECT value FROM st_var WHERE name = 'var' + // ^^^ + // ------------------------------------------- + BinOp::Eq, + Box::new(var_name_field), + Box::new(var_name_value), + ); + + Ok(ProjectList::List( + RelExpr::Select(Box::new(relvar), filter), + column_list, + )) } /// Type-checker for regular `SQL` queries @@ -266,14 +401,14 @@ impl TypeChecker for SqlChecker { } } -fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> { +pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> { match parse_sql(sql)? { - SqlAst::Insert(insert) => Ok(Statement::Insert(type_insert(insert, tx)?)), - SqlAst::Delete(delete) => Ok(Statement::Delete(type_delete(delete, tx)?)), - SqlAst::Update(update) => Ok(Statement::Update(type_update(update, tx)?)), SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)), - SqlAst::Set(set) => Ok(Statement::Set(type_set(set)?)), - SqlAst::Show(show) => Ok(Statement::Show(type_show(show)?)), + SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))), + SqlAst::Delete(delete) => Ok(Statement::DML(DML::Delete(type_delete(delete, tx)?))), + SqlAst::Update(update) => Ok(Statement::DML(DML::Update(type_update(update, tx)?))), + SqlAst::Set(set) => Ok(Statement::DML(DML::Insert(type_and_rewrite_set(set, tx)?))), + SqlAst::Show(show) => Ok(Statement::Select(type_and_rewrite_show(show, tx)?)), } } diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs index ca09c6df0ef..fa9e48774a3 100644 --- a/crates/lib/src/lib.rs +++ b/crates/lib/src/lib.rs @@ -10,10 +10,12 @@ pub mod address; pub mod db; pub mod error; pub mod identity; +pub mod metrics; pub mod operator; pub mod query; pub mod relation; pub mod scheduler; +pub mod st_var; pub mod version; pub mod type_def { diff --git a/crates/lib/src/metrics.rs b/crates/lib/src/metrics.rs new file mode 100644 index 00000000000..c99bc3bd9e2 --- /dev/null +++ b/crates/lib/src/metrics.rs @@ -0,0 +1,52 @@ +/// Metrics collected during the course of a transaction +#[derive(Default)] +pub struct ExecutionMetrics { + /// How many times is an index probed? + /// + /// Note that a single btree scan may return many values, + /// but will only result in a single index seek. + pub index_seeks: usize, + /// How many rows are iterated over? + /// + /// It is independent of the number of rows returned. + /// A query for example may return a single row, + /// but if it scans the entire table to find that row, + /// this metric will reflect that. + pub rows_scanned: usize, + /// How many bytes are read? + /// + /// This metric is incremented anytime we dereference a `RowPointer`. + /// + /// For reducers this happens at the WASM boundary, + /// when serializing entire rows via the BSATN encoding. + /// + /// In addition to the same BSATN serialization of the output rows, + /// queries will dereference a `RowPointer` for column projections. + /// Such is the case for filters as well as index and hash joins. + /// + /// One place where this metric is not tracked is index scans. + /// Specifically the key comparisons that occur during the scan.
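+ /// Those comparisons never pass through the instrumented projection or + /// serialization paths, so there is currently no hook to count them.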
+ pub bytes_scanned: usize, + /// How many bytes are written? + /// + /// Note that this is the same as bytes inserted, + /// because deletes just update a free list in the datastore. + /// They don't actually write or clear page memory. + pub bytes_written: usize, + /// How many bytes did we send to clients? + /// + /// This is not necessarily the same as bytes scanned, + /// since a single query may send bytes to multiple clients. + /// + /// In general, these are BSATN bytes, but JSON is also possible. + pub bytes_sent_to_clients: usize, +} + +impl ExecutionMetrics { + pub fn merge(&mut self, with: ExecutionMetrics) { + self.index_seeks += with.index_seeks; + self.rows_scanned += with.rows_scanned; + self.bytes_scanned += with.bytes_scanned; + self.bytes_written += with.bytes_written; + self.bytes_sent_to_clients += with.bytes_sent_to_clients; + } +} diff --git a/crates/lib/src/st_var.rs b/crates/lib/src/st_var.rs new file mode 100644 index 00000000000..a1af179da2f --- /dev/null +++ b/crates/lib/src/st_var.rs @@ -0,0 +1,143 @@ +use derive_more::From; +use spacetimedb_sats::{AlgebraicValue, SpacetimeType, SumValue}; + +/// The value of a system variable in `st_var`. +/// Defined here because it is used in both the datastore and query. +#[derive(Debug, Clone, From, SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub enum StVarValue { + Bool(bool), + I8(i8), + U8(u8), + I16(i16), + U16(u16), + I32(i32), + U32(u32), + I64(i64), + U64(u64), + I128(i128), + U128(u128), + // No support for u/i256 added here as it seems unlikely to be useful. + F32(f32), + F64(f64), + String(Box), +} + +impl StVarValue { + pub fn try_from_primitive(value: AlgebraicValue) -> Result<Self, AlgebraicValue> { + match value { + AlgebraicValue::Bool(v) => Ok(StVarValue::Bool(v)), + AlgebraicValue::I8(v) => Ok(StVarValue::I8(v)), + AlgebraicValue::U8(v) => Ok(StVarValue::U8(v)), + AlgebraicValue::I16(v) => Ok(StVarValue::I16(v)), + AlgebraicValue::U16(v) => Ok(StVarValue::U16(v)), + AlgebraicValue::I32(v) => Ok(StVarValue::I32(v)), + AlgebraicValue::U32(v) => Ok(StVarValue::U32(v)), + AlgebraicValue::I64(v) => Ok(StVarValue::I64(v)), + AlgebraicValue::U64(v) => Ok(StVarValue::U64(v)), + AlgebraicValue::I128(v) => Ok(StVarValue::I128(v.0)), + AlgebraicValue::U128(v) => Ok(StVarValue::U128(v.0)), + AlgebraicValue::F32(v) => Ok(StVarValue::F32(v.into_inner())), + AlgebraicValue::F64(v) => Ok(StVarValue::F64(v.into_inner())), + AlgebraicValue::String(v) => Ok(StVarValue::String(v)), + _ => Err(value), + } + } + + pub fn try_from_sum(value: AlgebraicValue) -> Result<Self, AlgebraicValue> { + value.into_sum()?.try_into() + } +} + +impl TryFrom<SumValue> for StVarValue { + type Error = AlgebraicValue; + + fn try_from(sum: SumValue) -> Result<Self, Self::Error> { + match sum.tag { + 0 => Ok(StVarValue::Bool(sum.value.into_bool()?)), + 1 => Ok(StVarValue::I8(sum.value.into_i8()?)), + 2 => Ok(StVarValue::U8(sum.value.into_u8()?)), + 3 => Ok(StVarValue::I16(sum.value.into_i16()?)), + 4 => Ok(StVarValue::U16(sum.value.into_u16()?)), + 5 => Ok(StVarValue::I32(sum.value.into_i32()?)), + 6 => Ok(StVarValue::U32(sum.value.into_u32()?)), + 7 => Ok(StVarValue::I64(sum.value.into_i64()?)), + 8 => Ok(StVarValue::U64(sum.value.into_u64()?)), + 9 => Ok(StVarValue::I128(sum.value.into_i128()?.0)), + 10 => Ok(StVarValue::U128(sum.value.into_u128()?.0)), + 11 => Ok(StVarValue::F32(sum.value.into_f32()?.into_inner())), + 12 => Ok(StVarValue::F64(sum.value.into_f64()?.into_inner())), + 13 => Ok(StVarValue::String(sum.value.into_string()?)), + _ => Err(*sum.value), + } + } +} + +impl From<StVarValue> for AlgebraicValue { + fn from(value: StVarValue) -> Self {
AlgebraicValue::Sum(value.into()) + } +} + +impl From for SumValue { + fn from(value: StVarValue) -> Self { + match value { + StVarValue::Bool(v) => SumValue { + tag: 0, + value: Box::new(AlgebraicValue::Bool(v)), + }, + StVarValue::I8(v) => SumValue { + tag: 1, + value: Box::new(AlgebraicValue::I8(v)), + }, + StVarValue::U8(v) => SumValue { + tag: 2, + value: Box::new(AlgebraicValue::U8(v)), + }, + StVarValue::I16(v) => SumValue { + tag: 3, + value: Box::new(AlgebraicValue::I16(v)), + }, + StVarValue::U16(v) => SumValue { + tag: 4, + value: Box::new(AlgebraicValue::U16(v)), + }, + StVarValue::I32(v) => SumValue { + tag: 5, + value: Box::new(AlgebraicValue::I32(v)), + }, + StVarValue::U32(v) => SumValue { + tag: 6, + value: Box::new(AlgebraicValue::U32(v)), + }, + StVarValue::I64(v) => SumValue { + tag: 7, + value: Box::new(AlgebraicValue::I64(v)), + }, + StVarValue::U64(v) => SumValue { + tag: 8, + value: Box::new(AlgebraicValue::U64(v)), + }, + StVarValue::I128(v) => SumValue { + tag: 9, + value: Box::new(AlgebraicValue::I128(v.into())), + }, + StVarValue::U128(v) => SumValue { + tag: 10, + value: Box::new(AlgebraicValue::U128(v.into())), + }, + StVarValue::F32(v) => SumValue { + tag: 11, + value: Box::new(AlgebraicValue::F32(v.into())), + }, + StVarValue::F64(v) => SumValue { + tag: 12, + value: Box::new(AlgebraicValue::F64(v.into())), + }, + StVarValue::String(v) => SumValue { + tag: 13, + value: Box::new(AlgebraicValue::String(v)), + }, + } + } +} diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs index bc1e13329d0..2d00d72f4b7 100644 --- a/crates/physical-plan/src/compile.rs +++ b/crates/physical-plan/src/compile.rs @@ -2,12 +2,12 @@ use std::collections::HashMap; -use crate::plan::{ - HashJoin, Label, PhysicalCtx, PhysicalExpr, PhysicalPlan, ProjectListPlan, ProjectPlan, Semi, TupleField, -}; +use crate::dml::{DeletePlan, MutationPlan, UpdatePlan}; +use crate::plan::{HashJoin, Label, PhysicalExpr, PhysicalPlan, ProjectListPlan, ProjectPlan, Semi, TupleField}; +use crate::{PhysicalCtx, PlanCtx}; use spacetimedb_expr::expr::{Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, RelExpr, Relvar}; -use spacetimedb_expr::statement::Statement; +use spacetimedb_expr::statement::{Statement, DML}; use spacetimedb_expr::StatementCtx; pub trait VarLabel { @@ -34,7 +34,7 @@ fn compile_project_list(var: &mut impl VarLabel, expr: ProjectList) -> ProjectLi compile_rel_expr(var, proj), fields .into_iter() - .map(|(alias, expr)| (alias, compile_field_project(var, expr))) + .map(|(_, expr)| compile_field_project(var, expr)) .collect(), ), } @@ -116,67 +116,70 @@ fn compile_rel_expr(var: &mut impl VarLabel, ast: RelExpr) -> PhysicalPlan { } } -/// Compile a logical subscribe expression -pub fn compile_project_plan(project: ProjectName) -> ProjectPlan { - struct Interner { - next: usize, - names: HashMap, +/// Generates unique ids for named entities in a query plan +#[derive(Default)] +struct NamesToIds { + next_id: usize, + map: HashMap, +} + +impl NamesToIds { + fn into_map(self) -> HashMap { + self.map } - impl VarLabel for Interner { - fn label(&mut self, name: &str) -> Label { - if let Some(id) = self.names.get(name) { - return Label(*id); - } - self.next += 1; - self.names.insert(name.to_owned(), self.next); - self.next.into() +} + +impl VarLabel for NamesToIds { + fn label(&mut self, name: &str) -> Label { + if let Some(id) = self.map.get(name) { + return Label(*id); } + self.next_id += 1; + self.map.insert(name.to_owned(), self.next_id); + 
self.next_id.into() } - compile_project_name( - &mut Interner { - next: 0, - names: HashMap::new(), - }, - project, - ) } -/// Compile a SQL statement into a physical plan. -/// -/// The input [Statement] is assumed to be valid so the lowering is not expected to fail. -/// -/// **NOTE:** It does not optimize the plan. -pub fn compile(ast: StatementCtx<'_>) -> PhysicalCtx<'_> { - struct Interner { - next: usize, - names: HashMap, - } - impl VarLabel for Interner { - fn label(&mut self, name: &str) -> Label { - if let Some(id) = self.names.get(name) { - return Label(*id); - } - self.next += 1; - self.names.insert(name.to_owned(), self.next); - self.next.into() - } +/// Converts a logical selection into a physical plan. +/// Note, this utility is specific to subscriptions, +/// in that it does not support explicit column projections. +pub fn compile_select(project: ProjectName) -> ProjectPlan { + compile_project_name(&mut NamesToIds::default(), project) +} + +/// Converts a logical selection into a physical plan. +/// Note, this utility is applicable to a generic selections. +/// In particular, it supports explicit column projections. +pub fn compile_select_list(project: ProjectList) -> ProjectListPlan { + compile_select_list_raw(&mut NamesToIds::default(), project) +} + +pub fn compile_select_list_raw(var: &mut impl VarLabel, project: ProjectList) -> ProjectListPlan { + compile_project_list(var, project) +} + +/// Converts a logical DML statement into a physical plan, +/// but does not optimize it. +pub fn compile_dml_plan(stmt: DML) -> MutationPlan { + match stmt { + DML::Insert(insert) => MutationPlan::Insert(insert.into()), + DML::Delete(delete) => MutationPlan::Delete(DeletePlan::compile(delete)), + DML::Update(update) => MutationPlan::Update(UpdatePlan::compile(update)), } - let mut var = Interner { - next: 0, - names: HashMap::new(), - }; +} + +pub fn compile(ast: StatementCtx<'_>) -> PhysicalCtx<'_> { + let mut vars = NamesToIds::default(); let plan = match ast.statement { - Statement::Select(expr) => compile_project_list(&mut var, expr), - _ => { - unreachable!("Only `SELECT` is implemented") - } + Statement::Select(project) => PlanCtx::ProjectList(compile_select_list_raw(&mut vars, project)), + Statement::DML(stmt) => PlanCtx::DML(compile_dml_plan(stmt)), }; PhysicalCtx { plan, sql: ast.sql, - vars: var.names, + vars: vars.into_map(), source: ast.source, - planning_time: ast.planning_time, + planning_time: None, } } diff --git a/crates/physical-plan/src/dml.rs b/crates/physical-plan/src/dml.rs new file mode 100644 index 00000000000..332af40aace --- /dev/null +++ b/crates/physical-plan/src/dml.rs @@ -0,0 +1,115 @@ +use std::sync::Arc; + +use spacetimedb_expr::{ + expr::{ProjectName, RelExpr, Relvar}, + statement::{TableDelete, TableInsert, TableUpdate}, +}; +use spacetimedb_lib::{AlgebraicValue, ProductValue}; +use spacetimedb_primitives::ColId; +use spacetimedb_schema::schema::TableSchema; + +use crate::{compile::compile_select, plan::ProjectPlan}; + +/// A plan for mutating a table in the database +#[derive(Debug)] +pub enum MutationPlan { + Insert(InsertPlan), + Delete(DeletePlan), + Update(UpdatePlan), +} + +impl MutationPlan { + /// Optimizes the filters in updates and deletes + pub fn optimize(self) -> Self { + match self { + Self::Insert(..) 
=> self, + Self::Delete(plan) => Self::Delete(plan.optimize()), + Self::Update(plan) => Self::Update(plan.optimize()), + } + } +} + +/// A plan for inserting rows into a table +#[derive(Debug)] +pub struct InsertPlan { + pub table: Arc, + pub rows: Vec, +} + +impl From for InsertPlan { + fn from(insert: TableInsert) -> Self { + let TableInsert { table, rows } = insert; + let rows = rows.into_vec(); + Self { table, rows } + } +} + +/// A plan for deleting rows from a table +#[derive(Debug)] +pub struct DeletePlan { + pub table: Arc, + pub filter: ProjectPlan, +} + +impl DeletePlan { + /// Optimize the filter part of the delete + fn optimize(self) -> Self { + let Self { table, filter } = self; + let filter = filter.optimize(); + Self { table, filter } + } + + /// Logical to physical conversion + pub(crate) fn compile(delete: TableDelete) -> Self { + let TableDelete { table, filter } = delete; + let schema = table.clone(); + let alias = table.table_name.clone(); + let relvar = RelExpr::RelVar(Relvar { + schema, + alias, + delta: None, + }); + let project = match filter { + None => ProjectName::None(relvar), + Some(expr) => ProjectName::None(RelExpr::Select(Box::new(relvar), expr)), + }; + let filter = compile_select(project); + Self { table, filter } + } +} + +/// A plan for updating rows in a table +#[derive(Debug)] +pub struct UpdatePlan { + pub table: Arc, + pub columns: Vec<(ColId, AlgebraicValue)>, + pub filter: ProjectPlan, +} + +impl UpdatePlan { + /// Optimize the filter part of the update + fn optimize(self) -> Self { + let Self { table, columns, filter } = self; + let filter = filter.optimize(); + Self { columns, table, filter } + } + + /// Logical to physical conversion + pub(crate) fn compile(update: TableUpdate) -> Self { + let TableUpdate { table, columns, filter } = update; + let schema = table.clone(); + let alias = table.table_name.clone(); + let relvar = RelExpr::RelVar(Relvar { + schema, + alias, + delta: None, + }); + let project = match filter { + None => ProjectName::None(relvar), + Some(expr) => ProjectName::None(RelExpr::Select(Box::new(relvar), expr)), + }; + let filter = compile_select(project); + let columns = columns.into_vec(); + Self { columns, table, filter } + } +} diff --git a/crates/physical-plan/src/lib.rs b/crates/physical-plan/src/lib.rs index 1c0251d523d..6c510ea80dc 100644 --- a/crates/physical-plan/src/lib.rs +++ b/crates/physical-plan/src/lib.rs @@ -1,4 +1,45 @@ +use crate::dml::MutationPlan; +use crate::plan::ProjectListPlan; +use spacetimedb_expr::StatementSource; +use std::collections::HashMap; + pub mod compile; +pub mod dml; pub mod plan; pub mod printer; pub mod rules; + +#[derive(Debug)] +pub enum PlanCtx { + ProjectList(ProjectListPlan), + DML(MutationPlan), +} + +impl PlanCtx { + pub(crate) fn optimize(self) -> PlanCtx { + match self { + Self::ProjectList(plan) => Self::ProjectList(plan.optimize()), + Self::DML(plan) => Self::DML(plan.optimize()), + } + } +} + +/// A physical context for the result of a query compilation. 
+#[derive(Debug)] +pub struct PhysicalCtx<'a> { + pub plan: PlanCtx, + pub sql: &'a str, + // A map from table names to their labels + pub vars: HashMap, + pub source: StatementSource, + pub planning_time: Option, +} + +impl PhysicalCtx<'_> { + pub fn optimize(self) -> Self { + Self { + plan: self.plan.optimize(), + ..self + } + } +} diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index c8fbfa1250a..281cd54c482 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -1,13 +1,11 @@ use derive_more::From; -use std::collections::HashMap; use std::{ borrow::Cow, ops::{Bound, Deref, DerefMut}, sync::Arc, }; -use spacetimedb_expr::StatementSource; -use spacetimedb_lib::{query::Delta, AlgebraicValue, ProductValue}; +use spacetimedb_lib::{query::Delta, sats::size_of::SizeOf, AlgebraicValue, ProductValue}; use spacetimedb_primitives::{ColId, ColSet, IndexId}; use spacetimedb_schema::schema::{IndexSchema, TableSchema}; use spacetimedb_sql_parser::ast::{BinOp, LogOp}; @@ -100,7 +98,18 @@ impl ProjectPlan { #[derive(Debug)] pub enum ProjectListPlan { Name(ProjectPlan), - List(PhysicalPlan, Vec<(Box, TupleField)>), + List(PhysicalPlan, Vec), +} + +impl Deref for ProjectListPlan { + type Target = PhysicalPlan; + + fn deref(&self) -> &Self::Target { + match self { + Self::Name(plan) => plan, + Self::List(plan, ..) => plan, + } + } } impl ProjectListPlan { @@ -108,13 +117,7 @@ impl ProjectListPlan { match self { Self::Name(plan) => Self::Name(plan.optimize()), Self::List(plan, fields) => Self::List( - plan.optimize( - fields - .iter() - .map(|(_, TupleField { label, .. })| label) - .copied() - .collect(), - ), + plan.optimize(fields.iter().map(|TupleField { label, .. }| label).copied().collect()), fields, ), } @@ -561,7 +564,7 @@ impl PhysicalPlan { pub(crate) fn returns_distinct_values(&self, label: &Label, cols: &ColSet) -> bool { match self { // Is there a unique constraint for these cols? - Self::TableScan(schema, var, _) => var == label && schema.is_unique(cols), + Self::TableScan(schema, var, _) => var == label && schema.as_ref().is_unique(cols), // Is there a unique constraint for these cols + the index cols? 
Self::IxScan( IxScan { @@ -573,7 +576,7 @@ impl PhysicalPlan { var, ) => { var == label - && schema.is_unique(&ColSet::from_iter( + && schema.as_ref().is_unique(&ColSet::from_iter( cols.iter() .chain(prefix.iter().map(|(col_id, _)| *col_id)) .chain(vec![*col]), @@ -606,7 +609,7 @@ impl PhysicalPlan { _, ) => { lhs.returns_distinct_values(lhs_label, &ColSet::from(ColId(*lhs_field_pos as u16))) - && rhs.is_unique(cols) + && rhs.as_ref().is_unique(cols) } // If the table in question is on the lhs, // and if the lhs returns distinct values, @@ -903,8 +906,21 @@ impl PhysicalExpr { self.eval(row).as_bool().copied().unwrap_or(false) } + /// Evaluate this boolean expression over `row` + pub fn eval_bool_with_metrics(&self, row: &impl ProjectField, bytes_scanned: &mut usize) -> bool { + self.eval_with_metrics(row, bytes_scanned) + .as_bool() + .copied() + .unwrap_or(false) + } + /// Evaluate this expression over `row` fn eval(&self, row: &impl ProjectField) -> Cow<'_, AlgebraicValue> { + self.eval_with_metrics(row, &mut 0) + } + + /// Evaluate this expression over `row` + fn eval_with_metrics(&self, row: &impl ProjectField, bytes_scanned: &mut usize) -> Cow<'_, AlgebraicValue> { fn eval_bin_op(op: BinOp, a: &AlgebraicValue, b: &AlgebraicValue) -> bool { match op { BinOp::Eq => a == b, @@ -917,20 +933,28 @@ impl PhysicalExpr { } let into = |b| Cow::Owned(AlgebraicValue::Bool(b)); match self { - Self::BinOp(op, a, b) => into(eval_bin_op(*op, &a.eval(row), &b.eval(row))), + Self::BinOp(op, a, b) => into(eval_bin_op( + *op, + &a.eval_with_metrics(row, bytes_scanned), + &b.eval_with_metrics(row, bytes_scanned), + )), Self::LogOp(LogOp::And, exprs) => into( exprs .iter() // ALL is equivalent to AND - .all(|expr| expr.eval_bool(row)), + .all(|expr| expr.eval_bool_with_metrics(row, bytes_scanned)), ), Self::LogOp(LogOp::Or, exprs) => into( exprs .iter() // ANY is equivalent to OR - .any(|expr| expr.eval_bool(row)), + .any(|expr| expr.eval_bool_with_metrics(row, bytes_scanned)), ), - Self::Field(field) => Cow::Owned(row.project(field)), + Self::Field(field) => { + let value = row.project(field); + *bytes_scanned += value.size_of(); + Cow::Owned(value) + } Self::Value(v) => Cow::Borrowed(v), } } @@ -955,30 +979,10 @@ impl PhysicalExpr { } } -/// A physical context for the result of a query compilation. 
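`eval_with_metrics` threads a single `&mut usize` through the whole expression tree, and only the `Field` leaves, where row data is actually materialized, add to it, using the unpadded sizes defined by the new `SizeOf` trait. A standalone illustration of the accounting (values chosen purely for the arithmetic):

```rust
use spacetimedb_lib::{sats::size_of::SizeOf, AlgebraicValue};

fn main() {
    // A predicate that projects a 5-byte string column and an 8-byte u64
    // column counts 13 bytes scanned for that row: lengths only, no padding.
    let scanned = AlgebraicValue::String("hello".into()).size_of()
        + AlgebraicValue::U64(42).size_of();
    assert_eq!(scanned, 13);
}
```

Constant operands (`PhysicalExpr::Value`) are borrowed rather than projected, so they never inflate the counter.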
-#[derive(Debug)] -pub struct PhysicalCtx<'a> { - pub plan: ProjectListPlan, - pub sql: &'a str, - // A map from table names to their labels - pub vars: HashMap, - pub source: StatementSource, - pub planning_time: Option, -} - -impl<'a> PhysicalCtx<'a> { - pub fn optimize(self) -> Self { - Self { - plan: self.plan.optimize(), - ..self - } - } -} - pub mod tests_utils { - use super::*; use crate::compile::compile; use crate::printer::{Explain, ExplainOptions}; + use crate::PhysicalCtx; use expect_test::Expect; use spacetimedb_expr::check::{compile_sql_sub, SchemaView}; use spacetimedb_expr::statement::compile_sql_stmt; @@ -1017,24 +1021,18 @@ pub mod tests_utils { mod tests { use super::*; - use crate::compile::compile_project_plan; - use crate::plan::TupleField; use crate::printer::ExplainOptions; use expect_test::{expect, Expect}; - use pretty_assertions::assert_eq; - use spacetimedb_expr::check::{parse_and_type_sub, SchemaView}; + use spacetimedb_expr::check::SchemaView; use spacetimedb_lib::{ db::auth::{StAccess, StTableType}, - AlgebraicType, AlgebraicValue, + AlgebraicType, }; use spacetimedb_primitives::{ColId, ColList, ColSet, TableId}; use spacetimedb_schema::{ def::{BTreeAlgorithm, ConstraintData, IndexAlgorithm, UniqueConstraintData}, schema::{ColumnSchema, ConstraintSchema, IndexSchema, TableSchema}, }; - use spacetimedb_sql_parser::ast::BinOp; - - use super::{PhysicalExpr, ProjectPlan}; struct SchemaViewer { schemas: Vec>, @@ -1126,6 +1124,14 @@ mod tests { ) } + fn check_sub(db: &SchemaViewer, sql: &str, expect: Expect) { + tests_utils::check_sub(db, db.options, sql, expect); + } + + fn check_query(db: &SchemaViewer, sql: &str, expect: Expect) { + tests_utils::check_query(db, db.options, sql, expect); + } + /// No rewrites applied to a simple table scan #[test] fn table_scan_noop() { @@ -1244,18 +1250,15 @@ Seq Scan on t let db = SchemaViewer::new(vec![u.clone(), l.clone(), b.clone()]).optimize(true); - let sql = " + check_sub( + &db, + " select b.* from u join l as p on u.entity_id = p.entity_id join l as q on p.chunk = q.chunk join b on q.entity_id = b.entity_id - where u.identity = 5 - "; - - check_sub( - &db, - sql, + where u.identity = 5", expect![[r#" Index Join: Rhs on b -> Index Join: Rhs on q @@ -1270,111 +1273,6 @@ Index Join: Rhs on b Join Cond: (q.entity_id = b.entity_id) Output: b.entity_id, b.misc"#]], ); - - let lp = parse_and_type_sub(sql, &db).unwrap(); - let pp = compile_project_plan(lp).optimize(); - - // Plan: - // rx - // / \ - // rx b - // / \ - // rx l - // / \ - // ix(u) l - let plan = match pp { - ProjectPlan::None(plan) => plan, - proj => panic!("unexpected project: {:#?}", proj), - }; - - // Plan: - // rx - // / \ - // rx b - // / \ - // rx l - // / \ - // ix(u) l - let plan = match plan { - PhysicalPlan::IxJoin( - IxJoin { - lhs, - rhs, - rhs_field: ColId(0), - unique: true, - lhs_field: TupleField { field_pos: 0, .. }, - .. - }, - Semi::Rhs, - ) => { - assert_eq!(rhs.table_id, b_id); - *lhs - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: - // rx - // / \ - // rx l - // / \ - // ix(u) l - let plan = match plan { - PhysicalPlan::IxJoin( - IxJoin { - lhs, - rhs, - rhs_field: ColId(1), - unique: false, - lhs_field: TupleField { field_pos: 1, .. }, - .. 
- }, - Semi::Rhs, - ) => { - assert_eq!(rhs.table_id, l_id); - *lhs - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: - // rx - // / \ - // ix(u) l - let plan = match plan { - PhysicalPlan::IxJoin( - IxJoin { - lhs, - rhs, - rhs_field: ColId(0), - unique: true, - lhs_field: TupleField { field_pos: 1, .. }, - .. - }, - Semi::Rhs, - ) => { - assert_eq!(rhs.table_id, l_id); - *lhs - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: ix(u) - match plan { - PhysicalPlan::IxScan( - IxScan { - schema, - prefix, - arg: Sarg::Eq(ColId(0), AlgebraicValue::U64(5)), - .. - }, - _, - ) => { - assert!(prefix.is_empty()); - assert_eq!(schema.table_id, u_id); - } - plan => panic!("unexpected plan: {:#?}", plan), - } } /// Given the following operator notation: @@ -1452,19 +1350,16 @@ Index Join: Rhs on b let db = SchemaViewer::new(vec![m.clone(), w.clone(), p.clone()]).optimize(false); - let sql = " + check_sub( + &db, + " select p.* from m join m as n on m.manager = n.manager join w as u on n.employee = u.employee join w as v on u.project = v.project join p on p.id = v.project - where 5 = m.employee and 5 = v.employee - "; - - check_sub( - &db, - sql, + where 5 = m.employee and 5 = v.employee", expect![[r#" Hash Join -> Hash Join @@ -1486,154 +1381,6 @@ Hash Join Filter: (m.employee = U64(5) AND v.employee = U64(5)) Output: p.id, p.name"#]], ); - - let lp = parse_and_type_sub(sql, &db).unwrap(); - let pp = compile_project_plan(lp).optimize(); - - // Plan: - // rx - // / \ - // rj p - // / \ - // rx ix(w) - // / \ - // rx w - // / \ - // ix(m) m - let plan = match pp { - ProjectPlan::None(plan) => plan, - proj => panic!("unexpected project: {:#?}", proj), - }; - - // Plan: - // rx - // / \ - // rj p - // / \ - // rx ix(w) - // / \ - // rx w - // / \ - // ix(m) m - let plan = match plan { - PhysicalPlan::IxJoin( - IxJoin { - lhs, - rhs, - rhs_field: ColId(0), - unique: true, - lhs_field: TupleField { field_pos: 1, .. }, - .. - }, - Semi::Rhs, - ) => { - assert_eq!(rhs.table_id, p_id); - *lhs - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: - // rj - // / \ - // rx ix(w) - // / \ - // rx w - // / \ - // ix(m) m - let (rhs, lhs) = match plan { - PhysicalPlan::HashJoin( - HashJoin { - lhs, - rhs, - lhs_field: TupleField { field_pos: 1, .. }, - rhs_field: TupleField { field_pos: 1, .. }, - unique: true, - }, - Semi::Rhs, - ) => (*rhs, *lhs), - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: ix(w) - match rhs { - PhysicalPlan::IxScan( - IxScan { - schema, - prefix, - arg: Sarg::Eq(ColId(0), AlgebraicValue::U64(5)), - .. - }, - _, - ) => { - assert!(prefix.is_empty()); - assert_eq!(schema.table_id, w_id); - } - plan => panic!("unexpected plan: {:#?}", plan), - } - - // Plan: - // rx - // / \ - // rx w - // / \ - // ix(m) m - let plan = match lhs { - PhysicalPlan::IxJoin( - IxJoin { - lhs, - rhs, - rhs_field: ColId(0), - unique: false, - lhs_field: TupleField { field_pos: 0, .. }, - .. - }, - Semi::Rhs, - ) => { - assert_eq!(rhs.table_id, w_id); - *lhs - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: - // rx - // / \ - // ix(m) m - let plan = match plan { - PhysicalPlan::IxJoin( - IxJoin { - lhs, - rhs, - rhs_field: ColId(1), - unique: false, - lhs_field: TupleField { field_pos: 1, .. }, - .. 
- }, - Semi::Rhs, - ) => { - assert_eq!(rhs.table_id, m_id); - *lhs - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - // Plan: ix(m) - match plan { - PhysicalPlan::IxScan( - IxScan { - schema, - prefix, - arg: Sarg::Eq(ColId(0), AlgebraicValue::U64(5)), - .. - }, - _, - ) => { - assert!(prefix.is_empty()); - assert_eq!(schema.table_id, m_id); - } - plan => panic!("unexpected plan: {:#?}", plan), - } } /// Test single and multi-column index selections @@ -1657,10 +1404,9 @@ Hash Join let db = SchemaViewer::new(vec![t.clone()]).optimize(true); - let sql = "select * from t where x = 3 and y = 4 and z = 5"; check_sub( &db, - sql, + "select * from t where x = 3 and y = 4 and z = 5", expect![ r#" Index Scan using Index id 2 on t @@ -1668,97 +1414,6 @@ Index Scan using Index id 2 on t Output: t.w, t.x, t.y, t.z"# ], ); - - let lp = parse_and_type_sub(sql, &db).unwrap(); - let pp = compile_project_plan(lp).optimize(); - - // Select index on (x, y, z) - match pp { - ProjectPlan::None(PhysicalPlan::IxScan( - IxScan { - schema, prefix, arg, .. - }, - _, - )) => { - assert_eq!(schema.table_id, t_id); - assert_eq!(arg, Sarg::Eq(ColId(3), AlgebraicValue::U8(5))); - assert_eq!( - prefix, - vec![(ColId(1), AlgebraicValue::U8(3)), (ColId(2), AlgebraicValue::U8(4))] - ); - } - proj => panic!("unexpected plan: {:#?}", proj), - }; - - let sql = "select * from t where x = 3 and y = 4"; - let lp = parse_and_type_sub(sql, &db).unwrap(); - let pp = compile_project_plan(lp).optimize(); - - // Select index on x - let plan = match pp { - ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 2, .. }))); - assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U8(4)))); - *input - } - proj => panic!("unexpected plan: {:#?}", proj), - }; - - match plan { - PhysicalPlan::IxScan( - IxScan { - schema, prefix, arg, .. - }, - _, - ) => { - assert_eq!(schema.table_id, t_id); - assert_eq!(arg, Sarg::Eq(ColId(1), AlgebraicValue::U8(3))); - assert!(prefix.is_empty()); - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - let sql = "select * from t where w = 5 and x = 4"; - let lp = parse_and_type_sub(sql, &db).unwrap(); - let pp = compile_project_plan(lp).optimize(); - - // Select index on x - let plan = match pp { - ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 0, .. }))); - assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U8(5)))); - *input - } - proj => panic!("unexpected plan: {:#?}", proj), - }; - - match plan { - PhysicalPlan::IxScan( - IxScan { - schema, prefix, arg, .. - }, - _, - ) => { - assert_eq!(schema.table_id, t_id); - assert_eq!(arg, Sarg::Eq(ColId(1), AlgebraicValue::U8(4))); - assert!(prefix.is_empty()); - } - plan => panic!("unexpected plan: {:#?}", plan), - }; - - let sql = "select * from t where y = 1"; - let lp = parse_and_type_sub(sql, &db).unwrap(); - let pp = compile_project_plan(lp).optimize(); - - // Do not select index on (y, z) - match pp { - ProjectPlan::None(PhysicalPlan::Filter(input, PhysicalExpr::BinOp(BinOp::Eq, field, value))) => { - assert!(matches!(*input, PhysicalPlan::TableScan(..))); - assert!(matches!(*field, PhysicalExpr::Field(TupleField { field_pos: 2, .. 
}))); - assert!(matches!(*value, PhysicalExpr::Value(AlgebraicValue::U8(1)))); - } - proj => panic!("unexpected plan: {:#?}", proj), - }; } fn data() -> SchemaViewer { @@ -1796,14 +1451,6 @@ Index Scan using Index id 2 on t SchemaViewer::new(vec![m.clone(), w.clone(), p.clone()]).with_options(ExplainOptions::default().optimize(false)) } - fn check_sub(db: &SchemaViewer, sql: &str, expect: Expect) { - tests_utils::check_sub(db, db.options, sql, expect); - } - - fn check_query(db: &SchemaViewer, sql: &str, expect: Expect) { - tests_utils::check_query(db, db.options, sql, expect); - } - #[test] fn plan_metadata() { let db = data().with_options(ExplainOptions::new().with_schema().with_source().optimize(true)); diff --git a/crates/physical-plan/src/printer.rs b/crates/physical-plan/src/printer.rs index 601528b16d3..da5721961ff 100644 --- a/crates/physical-plan/src/printer.rs +++ b/crates/physical-plan/src/printer.rs @@ -1,17 +1,15 @@ +use crate::plan::{IxScan, Label, PhysicalExpr, PhysicalPlan, ProjectListPlan, ProjectPlan, Sarg, Semi, TupleField}; +use crate::{PhysicalCtx, PlanCtx}; use itertools::Itertools; -use std::collections::{BTreeMap, HashMap}; -use std::fmt; -use std::ops::Bound; - -use crate::plan::{ - IxScan, Label, PhysicalCtx, PhysicalExpr, PhysicalPlan, ProjectListPlan, ProjectPlan, Sarg, Semi, TupleField, -}; use spacetimedb_expr::StatementSource; use spacetimedb_lib::AlgebraicValue; use spacetimedb_primitives::{ColId, ConstraintId, IndexId}; use spacetimedb_schema::def::ConstraintData; use spacetimedb_schema::schema::{IndexSchema, TableSchema}; use spacetimedb_sql_parser::ast::BinOp; +use std::collections::{BTreeMap, HashMap}; +use std::fmt; +use std::ops::Bound; fn range_to_op(lower: &Bound, upper: &Bound) -> BinOp { match (lower, upper) { @@ -239,7 +237,7 @@ pub enum Line<'a> { }, } -impl<'a> Line<'a> { +impl Line<'_> { pub fn ident(&self) -> usize { let ident = match self { Line::TableScan { ident, .. 
} => *ident, @@ -270,11 +268,8 @@ enum Output<'a> { } impl<'a> Output<'a> { - fn tuples(fields: &[(Box, TupleField)], lines: &Lines<'a>) -> Vec> { - fields - .iter() - .map(|(_, field)| lines.labels.field(field).unwrap()) - .collect() + fn tuples(fields: &[TupleField], lines: &Lines<'a>) -> Vec> { + fields.iter().map(|field| lines.labels.field(field).unwrap()).collect() } fn fields(schema: &Schema<'a>) -> Vec> { @@ -472,18 +467,24 @@ impl<'a> Explain<'a> { /// Evaluate the plan and build the lines to print fn lines(&self) -> Lines<'a> { let mut lines = Lines::new(self.ctx.vars.iter().map(|(x, y)| (*y, x.as_str())).collect()); + match &self.ctx.plan { - ProjectListPlan::Name(ProjectPlan::None(plan)) => { - eval_plan(&mut lines, plan, 0); - } - ProjectListPlan::Name(ProjectPlan::Name(plan, label, _count)) => { - eval_plan(&mut lines, plan, 0); - let schema = lines.labels.table_by_label(label).unwrap(); - lines.output = Output::Star(Output::fields(schema)); - } - ProjectListPlan::List(plan, fields) => { - eval_plan(&mut lines, plan, 0); - lines.output = Output::Fields(Output::tuples(fields, &lines)); + PlanCtx::ProjectList(plan) => match plan { + ProjectListPlan::Name(ProjectPlan::None(plan)) => { + eval_plan(&mut lines, plan, 0); + } + ProjectListPlan::Name(ProjectPlan::Name(plan, label, _count)) => { + eval_plan(&mut lines, plan, 0); + let schema = lines.labels.table_by_label(label).unwrap(); + lines.output = Output::Star(Output::fields(schema)); + } + ProjectListPlan::List(plan, fields) => { + eval_plan(&mut lines, plan, 0); + lines.output = Output::Fields(Output::tuples(fields, &lines)); + } + }, + PlanCtx::DML(plan) => { + todo!() } } @@ -502,7 +503,7 @@ impl<'a> Explain<'a> { } } -impl<'a> fmt::Display for PrintExpr<'a> { +impl fmt::Display for PrintExpr<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.expr { PhysicalExpr::LogOp(op, expr) => { @@ -543,7 +544,7 @@ impl<'a> fmt::Display for PrintExpr<'a> { } } -impl<'a> fmt::Display for PrintSarg<'a> { +impl fmt::Display for PrintSarg<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.expr { Sarg::Eq(lhs, rhs) => { @@ -564,13 +565,13 @@ impl<'a> fmt::Display for PrintSarg<'a> { } } -impl<'a> fmt::Display for Field<'a> { +impl fmt::Display for Field<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}.{}", self.table, self.field) } } -impl<'a> fmt::Display for PrintName<'a> { +impl fmt::Display for PrintName<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { PrintName::Named { object, name } => write!(f, "{} {}", object, name), @@ -579,14 +580,14 @@ impl<'a> fmt::Display for PrintName<'a> { } } -impl<'a> fmt::Display for PrintIndex<'a> { +impl fmt::Display for PrintIndex<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}: ", self.name)?; write!(f, "({})", self.cols.iter().join(", ")) } } -impl<'a> fmt::Display for Explain<'a> { +impl fmt::Display for Explain<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let ctx = self.ctx; diff --git a/crates/query/src/delta.rs b/crates/query/src/delta.rs index d42055968fb..c5ce409ba22 100644 --- a/crates/query/src/delta.rs +++ b/crates/query/src/delta.rs @@ -4,9 +4,9 @@ use anyhow::{bail, Result}; use itertools::Either; use spacetimedb_execution::{pipelined::PipelinedProject, Datastore, DeltaStore, Row}; use spacetimedb_expr::check::{type_subscription, SchemaView}; -use spacetimedb_lib::query::Delta; +use spacetimedb_lib::{metrics::ExecutionMetrics, 
query::Delta}; use spacetimedb_physical_plan::{ - compile::compile_project_plan, + compile::compile_select, plan::{HashJoin, IxJoin, Label, PhysicalPlan, ProjectPlan}, }; use spacetimedb_primitives::TableId; @@ -40,7 +40,7 @@ impl DeltaPlan { let ast = parse_subscription(sql)?; let sub = type_subscription(ast, tx)?; - let Some(table_id) = sub.table_id() else { + let Some(table_id) = sub.return_table_id() else { bail!("Failed to determine TableId for query") }; @@ -48,7 +48,7 @@ impl DeltaPlan { bail!("TableId `{table_id}` does not exist") }; - let plan = compile_project_plan(sub); + let plan = compile_select(sub); let mut ix_joins = true; plan.visit(&mut |plan| match plan { @@ -140,11 +140,15 @@ pub struct DeltaPlanEvaluator { } impl DeltaPlanEvaluator { - pub fn eval_inserts<'a, Tx: Datastore + DeltaStore>(&'a self, tx: &'a Tx) -> Result>> { + pub fn eval_inserts<'a, Tx: Datastore + DeltaStore>( + &'a self, + tx: &'a Tx, + metrics: &mut ExecutionMetrics, + ) -> Result>> { let mut rows = vec![]; for plan in &self.insert_plans { let plan = PipelinedProject::from(plan.clone()); - plan.execute(tx, &mut |row| { + plan.execute(tx, metrics, &mut |row| { rows.push(row); Ok(()) })?; @@ -152,11 +156,15 @@ impl DeltaPlanEvaluator { Ok(rows.into_iter()) } - pub fn eval_deletes<'a, Tx: Datastore + DeltaStore>(&'a self, tx: &'a Tx) -> Result>> { + pub fn eval_deletes<'a, Tx: Datastore + DeltaStore>( + &'a self, + tx: &'a Tx, + metrics: &mut ExecutionMetrics, + ) -> Result>> { let mut rows = vec![]; for plan in &self.delete_plans { let plan = PipelinedProject::from(plan.clone()); - plan.execute(tx, &mut |row| { + plan.execute(tx, metrics, &mut |row| { rows.push(row); Ok(()) })?; diff --git a/crates/query/src/lib.rs b/crates/query/src/lib.rs index bd2ef48e6dc..5fc8f6c5167 100644 --- a/crates/query/src/lib.rs +++ b/crates/query/src/lib.rs @@ -4,11 +4,23 @@ use anyhow::{bail, Result}; use delta::DeltaPlan; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use spacetimedb_client_api_messages::websocket::{ - Compression, DatabaseUpdate, QueryUpdate, TableUpdate, WebsocketFormat, + ByteListLen, Compression, DatabaseUpdate, QueryUpdate, TableUpdate, WebsocketFormat, +}; +use spacetimedb_execution::{ + dml::{MutDatastore, MutExecutor}, + pipelined::{PipelinedProject, ProjectListExecutor}, + Datastore, DeltaStore, +}; +use spacetimedb_expr::{ + check::{type_subscription, SchemaView}, + expr::ProjectList, + statement::{parse_and_type_sql, Statement, DML}, +}; +use spacetimedb_lib::{metrics::ExecutionMetrics, ProductValue}; +use spacetimedb_physical_plan::{ + compile::{compile_dml_plan, compile_select, compile_select_list}, + plan::{ProjectListPlan, ProjectPlan}, }; -use spacetimedb_execution::{pipelined::PipelinedProject, Datastore, DeltaStore}; -use spacetimedb_expr::check::{type_subscription, SchemaView}; -use spacetimedb_physical_plan::{compile::compile_project_plan, plan::ProjectPlan}; use spacetimedb_primitives::TableId; use spacetimedb_sql_parser::parser::sub::parse_subscription; @@ -19,6 +31,39 @@ pub mod delta; /// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions. const MAX_SQL_LENGTH: usize = 50_000; +/// A utility for parsing and type checking a sql statement +pub fn compile_sql_stmt(sql: &str, tx: &impl SchemaView) -> Result { + if sql.len() > MAX_SQL_LENGTH { + bail!("SQL query exceeds maximum allowed length: \"{sql:.120}...\"") + } + Ok(parse_and_type_sql(sql, tx)?) 
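Taking the accumulator by `&mut` lets one `ExecutionMetrics` cover both halves of a delta instead of allocating a fresh value per call. A possible caller shape (a sketch: the signatures are assumed from this diff, and the iterator item is assumed to be the executor's `Row`):

```rust
use anyhow::Result;
use spacetimedb_execution::{Datastore, DeltaStore, Row};
use spacetimedb_lib::metrics::ExecutionMetrics;

use crate::delta::DeltaPlanEvaluator;

/// Hypothetical driver: fold inserts and deletes into a single metrics value.
fn eval_delta<'a, Tx: Datastore + DeltaStore>(
    eval: &'a DeltaPlanEvaluator,
    tx: &'a Tx,
) -> Result<(Vec<Row<'a>>, Vec<Row<'a>>, ExecutionMetrics)> {
    let mut metrics = ExecutionMetrics::default();
    let inserts = eval.eval_inserts(tx, &mut metrics)?.collect();
    let deletes = eval.eval_deletes(tx, &mut metrics)?.collect();
    Ok((inserts, deletes, metrics))
}
```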
+} + +/// A utility for executing a sql select statement +pub fn execute_select_stmt( + stmt: ProjectList, + tx: &Tx, + metrics: &mut ExecutionMetrics, + check_row_limit: impl Fn(ProjectListPlan) -> Result, +) -> Result> { + let plan = compile_select_list(stmt).optimize(); + let plan = check_row_limit(plan)?; + let plan = ProjectListExecutor::from(plan); + let mut rows = vec![]; + plan.execute(tx, metrics, &mut |row| { + rows.push(row); + Ok(()) + })?; + Ok(rows) +} + +/// A utility for executing a sql dml statement +pub fn execute_dml_stmt(stmt: DML, tx: &mut Tx, metrics: &mut ExecutionMetrics) -> Result<()> { + let plan = compile_dml_plan(stmt).optimize(); + let plan = MutExecutor::from(plan); + plan.execute(tx, metrics) +} + /// A subscription query plan that is NOT used for incremental evaluation #[derive(Debug)] pub struct SubscribePlan { @@ -71,7 +116,7 @@ impl SubscribePlan { let ast = parse_subscription(sql)?; let sub = type_subscription(ast, tx)?; - let Some(table_id) = sub.table_id() else { + let Some(table_id) = sub.return_table_id() else { bail!("Failed to determine TableId for query") }; @@ -79,7 +124,7 @@ impl SubscribePlan { bail!("TableId `{table_id}` does not exist") }; - let plan = compile_project_plan(sub); + let plan = compile_select(sub); let plan = plan.optimize(); Ok(Self { @@ -90,37 +135,48 @@ impl SubscribePlan { } /// Execute a subscription query - pub fn execute(&self, tx: &Tx) -> Result<(F::List, u64)> + pub fn execute(&self, tx: &Tx) -> Result<(F::List, u64, ExecutionMetrics)> where Tx: Datastore + DeltaStore, F: WebsocketFormat, { let plan = PipelinedProject::from(self.plan.clone()); let mut rows = vec![]; - plan.execute(tx, &mut |row| { + let mut metrics = ExecutionMetrics::default(); + plan.execute(tx, &mut metrics, &mut |row| { rows.push(row); Ok(()) })?; - Ok(F::encode_list(rows.into_iter())) + let (list, n) = F::encode_list(rows.into_iter()); + metrics.bytes_scanned += list.num_bytes(); + metrics.bytes_sent_to_clients += list.num_bytes(); + Ok((list, n, metrics)) } /// Execute a subscription query and collect the results in a [TableUpdate] - pub fn collect_table_update(&self, comp: Compression, tx: &Tx) -> Result> + pub fn collect_table_update(&self, comp: Compression, tx: &Tx) -> Result<(TableUpdate, ExecutionMetrics)> where Tx: Datastore + DeltaStore, F: WebsocketFormat, { - self.execute::(tx).map(|(inserts, num_rows)| { + self.execute::(tx).map(|(inserts, num_rows, metrics)| { let deletes = F::List::default(); let qu = QueryUpdate { deletes, inserts }; let update = F::into_query_update(qu, comp); - TableUpdate::new(self.table_id, self.table_name.clone(), (update, num_rows)) + ( + TableUpdate::new(self.table_id, self.table_name.clone(), (update, num_rows)), + metrics, + ) }) } } /// Execute a collection of subscription queries in parallel -pub fn execute_plans(plans: Vec, comp: Compression, tx: &Tx) -> Result> +pub fn execute_plans( + plans: Vec, + comp: Compression, + tx: &Tx, +) -> Result<(DatabaseUpdate, ExecutionMetrics)> where Tx: Datastore + DeltaStore + Sync, F: WebsocketFormat, @@ -128,6 +184,15 @@ where plans .par_iter() .map(|plan| plan.collect_table_update(comp, tx)) - .collect::>() - .map(|tables| DatabaseUpdate { tables }) + .collect::>>() + .map(|table_updates_with_metrics| { + let n = table_updates_with_metrics.len(); + let mut tables = Vec::with_capacity(n); + let mut aggregated_metrics = ExecutionMetrics::default(); + for (update, metrics) in table_updates_with_metrics { + tables.push(update); + aggregated_metrics.merge(metrics); + } + 
(DatabaseUpdate { tables }, aggregated_metrics) + }) } diff --git a/crates/sats/src/convert.rs b/crates/sats/src/convert.rs index c43647ba354..422a6e1ac5f 100644 --- a/crates/sats/src/convert.rs +++ b/crates/sats/src/convert.rs @@ -42,6 +42,7 @@ built_in_into!(i256, I256); built_in_into!(f32, F32); built_in_into!(f64, F64); built_in_into!(&str, String); +built_in_into!(String, String); built_in_into!(&[u8], Bytes); built_in_into!(Box<[u8]>, Bytes); diff --git a/crates/sats/src/lib.rs b/crates/sats/src/lib.rs index b00aa7badc4..0db23ce3d8c 100644 --- a/crates/sats/src/lib.rs +++ b/crates/sats/src/lib.rs @@ -18,6 +18,7 @@ pub mod product_value; mod resolve_refs; pub mod satn; pub mod ser; +pub mod size_of; pub mod sum_type; pub mod sum_type_variant; pub mod sum_value; diff --git a/crates/sats/src/proptest.rs b/crates/sats/src/proptest.rs index 2b0c883b92b..11f09aca2bc 100644 --- a/crates/sats/src/proptest.rs +++ b/crates/sats/src/proptest.rs @@ -207,6 +207,18 @@ pub fn generate_typed_row() -> impl Strategy impl Strategy)> { + generate_row_type(0..=SIZE).prop_flat_map(move |ty| { + ( + Just(ty.clone()), + vec(generate_product_value(ty), num_rows_min..num_rows_max), + ) + }) +} + /// Generates a type `ty` and a value typed at `ty`. pub fn generate_typed_value() -> impl Strategy { generate_algebraic_type().prop_flat_map(|ty| (Just(ty.clone()), generate_algebraic_value(ty))) diff --git a/crates/sats/src/size_of.rs b/crates/sats/src/size_of.rs new file mode 100644 index 00000000000..9a71374d787 --- /dev/null +++ b/crates/sats/src/size_of.rs @@ -0,0 +1,121 @@ +use ethnum::{i256, u256}; + +use crate::{algebraic_value::Packed, AlgebraicValue, ArrayValue, ProductValue, SumValue, F32, F64}; + +pub trait SizeOf { + /// Returns the unpadded size in bytes of an [AlgebraicValue] or primitive + fn size_of(&self) -> usize; +} + +macro_rules! 
impl_size_of_primitive { + ($prim:ty) => { + impl SizeOf for $prim { + fn size_of(&self) -> usize { + std::mem::size_of::<$prim>() + } + } + }; + ($($prim:ty,)*) => { + $(impl_size_of_primitive!($prim);)* + }; +} + +impl_size_of_primitive!( + bool, + u8, + i8, + u16, + i16, + u32, + i32, + u64, + i64, + u128, + i128, + Packed<u128>, + Packed<i128>, + u256, + i256, + F32, + F64, +); + +impl SizeOf for Box<str> { + fn size_of(&self) -> usize { + self.len() + } +} + +impl SizeOf for AlgebraicValue { + fn size_of(&self) -> usize { + match self { + Self::Min | Self::Max => unreachable!(), + Self::String(x) => x.size_of(), + Self::Bool(x) => x.size_of(), + Self::U8(x) => x.size_of(), + Self::I8(x) => x.size_of(), + Self::U16(x) => x.size_of(), + Self::I16(x) => x.size_of(), + Self::U32(x) => x.size_of(), + Self::I32(x) => x.size_of(), + Self::U64(x) => x.size_of(), + Self::I64(x) => x.size_of(), + Self::U128(x) => x.size_of(), + Self::I128(x) => x.size_of(), + Self::U256(x) => x.size_of(), + Self::I256(x) => x.size_of(), + Self::F32(x) => x.size_of(), + Self::F64(x) => x.size_of(), + Self::Sum(x) => x.size_of(), + Self::Product(x) => x.size_of(), + Self::Array(x) => x.size_of(), + } + } +} + +impl SizeOf for SumValue { + fn size_of(&self) -> usize { + 1 + self.value.size_of() + } +} + +impl SizeOf for ProductValue { + fn size_of(&self) -> usize { + self.elements.size_of() + } +} + +impl<T> SizeOf for [T] +where + T: SizeOf, +{ + fn size_of(&self) -> usize { + self.iter().map(|elt| elt.size_of()).sum() + } +} + +impl SizeOf for ArrayValue { + fn size_of(&self) -> usize { + match self { + Self::Sum(elts) => elts.size_of(), + Self::Product(elts) => elts.size_of(), + Self::Bool(elts) => elts.size_of(), + Self::I8(elts) => elts.size_of(), + Self::U8(elts) => elts.size_of(), + Self::I16(elts) => elts.size_of(), + Self::U16(elts) => elts.size_of(), + Self::I32(elts) => elts.size_of(), + Self::U32(elts) => elts.size_of(), + Self::I64(elts) => elts.size_of(), + Self::U64(elts) => elts.size_of(), + Self::I128(elts) => elts.size_of(), + Self::U128(elts) => elts.size_of(), + Self::I256(elts) => elts.size_of(), + Self::U256(elts) => elts.size_of(), + Self::F32(elts) => elts.size_of(), + Self::F64(elts) => elts.size_of(), + Self::String(elts) => elts.size_of(), + Self::Array(elts) => elts.size_of(), + } + } +} diff --git a/crates/sdk/examples/quickstart-chat/main.rs b/crates/sdk/examples/quickstart-chat/main.rs index 352ea648f70..fe0529122a1 100644 --- a/crates/sdk/examples/quickstart-chat/main.rs +++ b/crates/sdk/examples/quickstart-chat/main.rs @@ -1,5 +1,7 @@ #![allow(clippy::disallowed_macros)] mod module_bindings; +use std::sync::{atomic::AtomicU8, Arc}; + use module_bindings::*; use spacetimedb_client_api_messages::websocket::Compression; @@ -116,7 +118,6 @@ fn on_sub_applied(ctx: &EventContext) { print_message(ctx, &message); } } - // ## Warn if set_name failed /// Our `on_set_name` callback: print a warning if the reducer failed.
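The `SizeOf` impls above compose structurally: a product is the plain sum of its elements and a sum adds exactly one tag byte, with no alignment padding or length prefixes counted. A worked example (a sketch assuming the public field layouts of `ProductValue` and `SumValue` in the sats crate):

```rust
use spacetimedb_sats::{size_of::SizeOf, AlgebraicValue, ProductValue, SumValue};

fn main() {
    // A row of (u32, "ab"): 4 + 2 = 6 bytes.
    let row = ProductValue {
        elements: vec![AlgebraicValue::U32(7), AlgebraicValue::String("ab".into())].into(),
    };
    assert_eq!(row.size_of(), 6);

    // Wrapping the row in a sum variant costs one tag byte: 1 + 6 = 7.
    let sum = SumValue {
        tag: 0,
        value: Box::new(AlgebraicValue::Product(row)),
    };
    assert_eq!(sum.size_of(), 7);
}
```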
@@ -178,12 +179,26 @@ fn connect_to_db() -> DbConnection { } // # Subscribe to queries +fn subscribe_to_queries(ctx: &DbConnection, queries: &[&str], callback: fn(&EventContext)) { + if queries.is_empty() { + panic!("No queries to subscribe to."); + } + let remaining_queries = Arc::new(AtomicU8::new(queries.len() as u8)); + for query in queries { + let remaining_queries = remaining_queries.clone(); + ctx.subscription_builder() + .on_applied(move |ctx| { + if remaining_queries.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) == 1 { + callback(ctx); + } + }) + .subscribe(query); + } +} /// Register subscriptions for all rows of both tables. fn subscribe_to_tables(ctx: &DbConnection) { - ctx.subscription_builder() - .on_applied(on_sub_applied) - .subscribe(["SELECT * FROM user;", "SELECT * FROM message;"]); + subscribe_to_queries(ctx, &["SELECT * FROM user", "SELECT * FROM message"], on_sub_applied); } // # Handle user input diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/identity_connected_reducer.rs b/crates/sdk/examples/quickstart-chat/module_bindings/identity_connected_reducer.rs index fafba96a97b..81ad808b6fe 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/identity_connected_reducer.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/identity_connected_reducer.rs @@ -24,17 +24,17 @@ impl __sdk::InModule for IdentityConnectedArgs { pub struct IdentityConnectedCallbackId(__sdk::CallbackId); #[allow(non_camel_case_types)] -/// Extension trait for access to the reducer `__identity_connected__`. +/// Extension trait for access to the reducer `identity_connected`. /// /// Implemented for [`super::RemoteReducers`]. pub trait identity_connected { - /// Request that the remote module invoke the reducer `__identity_connected__` to run as soon as possible. + /// Request that the remote module invoke the reducer `identity_connected` to run as soon as possible. /// /// This method returns immediately, and errors only if we are unable to send the request. /// The reducer will run asynchronously in the future, /// and its status can be observed by listening for [`Self::on_identity_connected`] callbacks. fn identity_connected(&self) -> __anyhow::Result<()>; - /// Register a callback to run whenever we are notified of an invocation of the reducer `__identity_connected__`. + /// Register a callback to run whenever we are notified of an invocation of the reducer `identity_connected`. 
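The countdown in `subscribe_to_queries` works because `fetch_sub` returns the value *before* the decrement: exactly one `on_applied`, whichever lands last, observes 1 and fires the user callback. `Relaxed` ordering suffices since only the atomicity of the decrement matters here, not ordering against other memory. The shape in isolation:

```rust
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;

fn main() {
    let remaining = Arc::new(AtomicU8::new(3));
    let mut fired = 0;
    for _ in 0..3 {
        // Previous values seen: 3, 2, 1. Only the final decrement fires.
        if remaining.fetch_sub(1, Ordering::Relaxed) == 1 {
            fired += 1;
        }
    }
    assert_eq!(fired, 1);
}
```

Note that the `u8` counter caps this helper at 255 queries; the `is_empty` panic guards the zero case but not overflow.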
/// /// The [`super::EventContext`] passed to the `callback` /// will always have [`__sdk::Event::Reducer`] as its `event`, @@ -55,15 +55,14 @@ pub trait identity_connected { impl identity_connected for super::RemoteReducers { fn identity_connected(&self) -> __anyhow::Result<()> { - self.imp - .call_reducer("__identity_connected__", IdentityConnectedArgs {}) + self.imp.call_reducer("identity_connected", IdentityConnectedArgs {}) } fn on_identity_connected( &self, mut callback: impl FnMut(&super::EventContext) + Send + 'static, ) -> IdentityConnectedCallbackId { IdentityConnectedCallbackId(self.imp.on_reducer( - "__identity_connected__", + "identity_connected", Box::new(move |ctx: &super::EventContext| { let super::EventContext { event: @@ -81,19 +80,19 @@ impl identity_connected for super::RemoteReducers { )) } fn remove_on_identity_connected(&self, callback: IdentityConnectedCallbackId) { - self.imp.remove_on_reducer("__identity_connected__", callback.0) + self.imp.remove_on_reducer("identity_connected", callback.0) } } #[allow(non_camel_case_types)] #[doc(hidden)] -/// Extension trait for setting the call-flags for the reducer `__identity_connected__`. +/// Extension trait for setting the call-flags for the reducer `identity_connected`. /// /// Implemented for [`super::SetReducerFlags`]. /// /// This type is currently unstable and may be removed without a major version bump. pub trait set_flags_for_identity_connected { - /// Set the call-reducer flags for the reducer `__identity_connected__` to `flags`. + /// Set the call-reducer flags for the reducer `identity_connected` to `flags`. /// /// This type is currently unstable and may be removed without a major version bump. fn identity_connected(&self, flags: __ws::CallReducerFlags); @@ -101,6 +100,6 @@ pub trait set_flags_for_identity_connected { impl set_flags_for_identity_connected for super::SetReducerFlags { fn identity_connected(&self, flags: __ws::CallReducerFlags) { - self.imp.set_call_reducer_flags("__identity_connected__", flags); + self.imp.set_call_reducer_flags("identity_connected", flags); } } diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/identity_disconnected_reducer.rs b/crates/sdk/examples/quickstart-chat/module_bindings/identity_disconnected_reducer.rs index 0f48475e772..c44bf57c7c7 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/identity_disconnected_reducer.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/identity_disconnected_reducer.rs @@ -24,17 +24,17 @@ impl __sdk::InModule for IdentityDisconnectedArgs { pub struct IdentityDisconnectedCallbackId(__sdk::CallbackId); #[allow(non_camel_case_types)] -/// Extension trait for access to the reducer `__identity_disconnected__`. +/// Extension trait for access to the reducer `identity_disconnected`. /// /// Implemented for [`super::RemoteReducers`]. pub trait identity_disconnected { - /// Request that the remote module invoke the reducer `__identity_disconnected__` to run as soon as possible. + /// Request that the remote module invoke the reducer `identity_disconnected` to run as soon as possible. /// /// This method returns immediately, and errors only if we are unable to send the request. /// The reducer will run asynchronously in the future, /// and its status can be observed by listening for [`Self::on_identity_disconnected`] callbacks. fn identity_disconnected(&self) -> __anyhow::Result<()>; - /// Register a callback to run whenever we are notified of an invocation of the reducer `__identity_disconnected__`. 
+ /// Register a callback to run whenever we are notified of an invocation of the reducer `identity_disconnected`. /// /// The [`super::EventContext`] passed to the `callback` /// will always have [`__sdk::Event::Reducer`] as its `event`, @@ -56,14 +56,14 @@ pub trait identity_disconnected { impl identity_disconnected for super::RemoteReducers { fn identity_disconnected(&self) -> __anyhow::Result<()> { self.imp - .call_reducer("__identity_disconnected__", IdentityDisconnectedArgs {}) + .call_reducer("identity_disconnected", IdentityDisconnectedArgs {}) } fn on_identity_disconnected( &self, mut callback: impl FnMut(&super::EventContext) + Send + 'static, ) -> IdentityDisconnectedCallbackId { IdentityDisconnectedCallbackId(self.imp.on_reducer( - "__identity_disconnected__", + "identity_disconnected", Box::new(move |ctx: &super::EventContext| { let super::EventContext { event: @@ -81,19 +81,19 @@ impl identity_disconnected for super::RemoteReducers { )) } fn remove_on_identity_disconnected(&self, callback: IdentityDisconnectedCallbackId) { - self.imp.remove_on_reducer("__identity_disconnected__", callback.0) + self.imp.remove_on_reducer("identity_disconnected", callback.0) } } #[allow(non_camel_case_types)] #[doc(hidden)] -/// Extension trait for setting the call-flags for the reducer `__identity_disconnected__`. +/// Extension trait for setting the call-flags for the reducer `identity_disconnected`. /// /// Implemented for [`super::SetReducerFlags`]. /// /// This type is currently unstable and may be removed without a major version bump. pub trait set_flags_for_identity_disconnected { - /// Set the call-reducer flags for the reducer `__identity_disconnected__` to `flags`. + /// Set the call-reducer flags for the reducer `identity_disconnected` to `flags`. /// /// This type is currently unstable and may be removed without a major version bump. fn identity_disconnected(&self, flags: __ws::CallReducerFlags); @@ -101,6 +101,6 @@ pub trait set_flags_for_identity_disconnected { impl set_flags_for_identity_disconnected for super::SetReducerFlags { fn identity_disconnected(&self, flags: __ws::CallReducerFlags) { - self.imp.set_call_reducer_flags("__identity_disconnected__", flags); + self.imp.set_call_reducer_flags("identity_disconnected", flags); } } diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/init_reducer.rs b/crates/sdk/examples/quickstart-chat/module_bindings/init_reducer.rs index 50081594cf0..c667dc4193c 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/init_reducer.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/init_reducer.rs @@ -24,17 +24,17 @@ impl __sdk::InModule for InitArgs { pub struct InitCallbackId(__sdk::CallbackId); #[allow(non_camel_case_types)] -/// Extension trait for access to the reducer `__init__`. +/// Extension trait for access to the reducer `init`. /// /// Implemented for [`super::RemoteReducers`]. pub trait init { - /// Request that the remote module invoke the reducer `__init__` to run as soon as possible. + /// Request that the remote module invoke the reducer `init` to run as soon as possible. /// /// This method returns immediately, and errors only if we are unable to send the request. /// The reducer will run asynchronously in the future, /// and its status can be observed by listening for [`Self::on_init`] callbacks. fn init(&self) -> __anyhow::Result<()>; - /// Register a callback to run whenever we are notified of an invocation of the reducer `__init__`. 
+ /// Register a callback to run whenever we are notified of an invocation of the reducer `init`. /// /// The [`super::EventContext`] passed to the `callback` /// will always have [`__sdk::Event::Reducer`] as its `event`, @@ -52,11 +52,11 @@ pub trait init { impl init for super::RemoteReducers { fn init(&self) -> __anyhow::Result<()> { - self.imp.call_reducer("__init__", InitArgs {}) + self.imp.call_reducer("init", InitArgs {}) } fn on_init(&self, mut callback: impl FnMut(&super::EventContext) + Send + 'static) -> InitCallbackId { InitCallbackId(self.imp.on_reducer( - "__init__", + "init", Box::new(move |ctx: &super::EventContext| { let super::EventContext { event: @@ -74,19 +74,19 @@ impl init for super::RemoteReducers { )) } fn remove_on_init(&self, callback: InitCallbackId) { - self.imp.remove_on_reducer("__init__", callback.0) + self.imp.remove_on_reducer("init", callback.0) } } #[allow(non_camel_case_types)] #[doc(hidden)] -/// Extension trait for setting the call-flags for the reducer `__init__`. +/// Extension trait for setting the call-flags for the reducer `init`. /// /// Implemented for [`super::SetReducerFlags`]. /// /// This type is currently unstable and may be removed without a major version bump. pub trait set_flags_for_init { - /// Set the call-reducer flags for the reducer `__init__` to `flags`. + /// Set the call-reducer flags for the reducer `init` to `flags`. /// /// This type is currently unstable and may be removed without a major version bump. fn init(&self, flags: __ws::CallReducerFlags); @@ -94,6 +94,6 @@ pub trait set_flags_for_init { impl set_flags_for_init for super::SetReducerFlags { fn init(&self, flags: __ws::CallReducerFlags) { - self.imp.set_call_reducer_flags("__init__", flags); + self.imp.set_call_reducer_flags("init", flags); } } diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs b/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs index dba79d38c0f..f64a13394ba 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs @@ -53,9 +53,9 @@ impl __sdk::InModule for Reducer { impl __sdk::Reducer for Reducer { fn reducer_name(&self) -> &'static str { match self { - Reducer::IdentityConnected => "__identity_connected__", - Reducer::IdentityDisconnected => "__identity_disconnected__", - Reducer::Init => "__init__", + Reducer::IdentityConnected => "identity_connected", + Reducer::IdentityDisconnected => "identity_disconnected", + Reducer::Init => "init", Reducer::SendMessage { .. } => "send_message", Reducer::SetName { .. } => "set_name", } @@ -65,15 +65,18 @@ impl TryFrom<__ws::ReducerCallInfo<__ws::BsatnFormat>> for Reducer { type Error = __anyhow::Error; fn try_from(value: __ws::ReducerCallInfo<__ws::BsatnFormat>) -> __anyhow::Result { match &value.reducer_name[..] { - "__identity_connected__" => Ok(__sdk::parse_reducer_args::< - identity_connected_reducer::IdentityConnectedArgs, - >("__identity_connected__", &value.args)? - .into()), - "__identity_disconnected__" => Ok(__sdk::parse_reducer_args::< + "identity_connected" => Ok( + __sdk::parse_reducer_args::( + "identity_connected", + &value.args, + )? + .into(), + ), + "identity_disconnected" => Ok(__sdk::parse_reducer_args::< identity_disconnected_reducer::IdentityDisconnectedArgs, - >("__identity_disconnected__", &value.args)? + >("identity_disconnected", &value.args)? 
.into()), - "__init__" => Ok(__sdk::parse_reducer_args::("__init__", &value.args)?.into()), + "init" => Ok(__sdk::parse_reducer_args::("init", &value.args)?.into()), "send_message" => Ok(__sdk::parse_reducer_args::( "send_message", &value.args, @@ -396,6 +399,7 @@ impl __sdk::EventContext for EventContext { /// A handle on a subscribed query. // TODO: Document this better after implementing the new subscription API. +#[derive(Clone)] pub struct SubscriptionHandle { imp: __sdk::SubscriptionHandleImpl, } @@ -408,6 +412,26 @@ impl __sdk::SubscriptionHandle for SubscriptionHandle { fn new(imp: __sdk::SubscriptionHandleImpl) -> Self { Self { imp } } + + /// Returns true if this subscription has been terminated due to an unsubscribe call or an error. + fn is_ended(&self) -> bool { + self.imp.is_ended() + } + + /// Returns true if this subscription has been applied and has not yet been unsubscribed. + fn is_active(&self) -> bool { + self.imp.is_active() + } + + /// Unsubscribe from the query controlled by this `SubscriptionHandle`, + /// then run `on_end` when its rows are removed from the client cache. + fn unsubscribe_then(self, on_end: __sdk::OnEndedCallback) -> __anyhow::Result<()> { + self.imp.unsubscribe_then(Some(on_end)) + } + + fn unsubscribe(self) -> __anyhow::Result<()> { + self.imp.unsubscribe_then(None) + } } /// Alias trait for a [`__sdk::DbContext`] connected to this module, diff --git a/crates/sdk/src/db_connection.rs b/crates/sdk/src/db_connection.rs index aa8b5856224..1d207c0de45 100644 --- a/crates/sdk/src/db_connection.rs +++ b/crates/sdk/src/db_connection.rs @@ -22,7 +22,9 @@ use crate::{ callbacks::{CallbackId, DbCallbacks, ReducerCallback, ReducerCallbacks, RowCallback, UpdateCallback}, client_cache::{ClientCache, TableHandle}, spacetime_module::{DbConnection, DbUpdate, EventContext, InModule, SpacetimeModule}, - subscription::{OnAppliedCallback, OnErrorCallback, SubscriptionManager}, + subscription::{ + OnAppliedCallback, OnErrorCallback, PendingUnsubscribeResult, SubscriptionHandleImpl, SubscriptionManager, + }, websocket::{WsConnection, WsParams}, Event, ReducerEvent, Status, }; @@ -36,7 +38,7 @@ use spacetimedb_client_api_messages::websocket::{BsatnFormat, CallReducerFlags, use spacetimedb_lib::{bsatn, ser::Serialize, Address, Identity}; use std::{ collections::HashMap, - sync::{Arc, Mutex as StdMutex, OnceLock}, + sync::{atomic::AtomicU32, Arc, Mutex as StdMutex, OnceLock}, time::{Duration, SystemTime}, }; use tokio::{ @@ -57,6 +59,9 @@ pub struct DbContextImpl { /// All the state which is safe to hold a lock on while running callbacks. pub(crate) inner: SharedCell>, + /// None if we have disconnected. + pub(crate) send_chan: SharedCell>>>, + /// The client cache, which stores subscribed rows. cache: SharedCell>, @@ -89,6 +94,7 @@ impl Clone for DbContextImpl { // since we'll be doing `DbContextImpl::clone` very frequently, // and we need it to be fast. inner: Arc::clone(&self.inner), + send_chan: Arc::clone(&self.send_chan), cache: Arc::clone(&self.cache), recv: Arc::clone(&self.recv), pending_mutations_send: self.pending_mutations_send.clone(), @@ -136,18 +142,17 @@ impl DbContextImpl { // Subscription applied: // set the received state to store all the rows, // then invoke the on-applied and row callbacks. + // We only use this for `subscribe_from_all_tables` ParsedMessage::InitialSubscription { db_update, sub_id } => { // Lock the client cache in a restricted scope, // so that it will be unlocked when callbacks run. 
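Since the generated handle is now `Clone` and exposes its lifecycle, callers can retain it and tear the query down later. A usage sketch against these quickstart-chat bindings (the builder returning the handle and the callback body are assumptions, not shown verbatim in this diff):

```rust
use module_bindings::*;
use spacetimedb_sdk::SubscriptionHandle as _;

fn subscribe_then_drop(ctx: &DbConnection) -> anyhow::Result<()> {
    let handle = ctx.subscription_builder().subscribe("SELECT * FROM message");
    // ... later, once the data is no longer needed:
    if handle.is_active() {
        handle.unsubscribe_then(Box::new(|_ctx| {
            println!("message rows removed from the client cache");
        }))?;
    }
    Ok(())
}
```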
{ let mut cache = self.cache.lock().unwrap(); - // FIXME: delete no-longer-subscribed rows. db_update.apply_to_client_cache(&mut *cache); } let event_ctx = self.make_event_ctx(Event::SubscribeApplied); let mut inner = self.inner.lock().unwrap(); - inner.subscriptions.subscription_applied(&event_ctx, sub_id); - // FIXME: invoke delete callbacks for no-longer-subscribed rows. + inner.subscriptions.legacy_subscription_applied(&event_ctx, sub_id); db_update.invoke_row_callbacks(&event_ctx, &mut inner.db_callbacks); Ok(()) } @@ -185,6 +190,51 @@ impl DbContextImpl { } Ok(()) } + ParsedMessage::SubscribeApplied { + query_id, + initial_update, + } => { + // Lock the client cache in a restricted scope, + // so that it will be unlocked when callbacks run. + { + let mut cache = self.cache.lock().unwrap(); + initial_update.apply_to_client_cache(&mut *cache); + } + let event_ctx = self.make_event_ctx(Event::SubscribeApplied); + let mut inner = self.inner.lock().unwrap(); + inner.subscriptions.subscription_applied(&event_ctx, query_id); + // FIXME: implement ref counting of rows to handle queries with overlapping results. + initial_update.invoke_row_callbacks(&event_ctx, &mut inner.db_callbacks); + Ok(()) + } + ParsedMessage::UnsubscribeApplied { + query_id, + initial_update, + } => { + // Lock the client cache in a restricted scope, + // so that it will be unlocked when callbacks run. + { + let mut cache = self.cache.lock().unwrap(); + initial_update.apply_to_client_cache(&mut *cache); + } + let event_ctx = self.make_event_ctx(Event::UnsubscribeApplied); + let mut inner = self.inner.lock().unwrap(); + inner.subscriptions.unsubscribe_applied(&event_ctx, query_id); + // FIXME: implement ref counting of rows to handle queries with overlapping results. + initial_update.invoke_row_callbacks(&event_ctx, &mut inner.db_callbacks); + Ok(()) + } + ParsedMessage::SubscriptionError { query_id, error } => { + let Some(query_id) = query_id else { + // A subscription error that isn't specific to a query is a fatal error. + self.invoke_disconnected(Some(&anyhow::anyhow!(error))); + return Ok(()); + }; + let mut inner = self.inner.lock().unwrap(); + let event_ctx = self.make_event_ctx(Event::SubscribeError(anyhow::anyhow!(error))); + inner.subscriptions.subscription_error(&event_ctx, query_id); + Ok(()) + } }; res @@ -192,23 +242,24 @@ impl DbContextImpl { /// Invoke the on-disconnect callback, and mark [`Self::is_active`] false. fn invoke_disconnected(&self, err: Option<&anyhow::Error>) { - let disconnected_callback = { - let mut inner = self.inner.lock().unwrap(); - // TODO: Determine correct behavior here. - // - Delete all rows from client cache? - // - Invoke `on_disconnect` methods? - // - End all subscriptions and invoke their `on_error` methods? - - // Set `send_chan` to `None`, since `Self::is_active` checks that. - inner.send_chan = None; - - // Grap the `on_disconnect` callback and invoke it. - inner.on_disconnect.take() - }; - if let Some(disconnect_callback) = disconnected_callback { + let mut inner = self.inner.lock().unwrap(); + // When we disconnect, we first call the on_disconnect method, + // then we call the `on_error` method for all subscriptions. + // We don't change the client cache at all. + + // Set `send_chan` to `None`, since `Self::is_active` checks that. + *self.send_chan.lock().unwrap() = None; + + // Grap the `on_disconnect` callback and invoke it. 
+ if let Some(disconnect_callback) = inner.on_disconnect.take() { let ctx = ::new(self.clone()); disconnect_callback(&ctx, err); } + + // Call the `on_disconnect` method for all subscriptions. + inner + .subscriptions + .on_disconnect(&self.make_event_ctx(Event::Disconnected)); } fn make_event_ctx(&self, event: Event) -> M::EventContext { @@ -236,9 +287,12 @@ impl DbContextImpl { on_error, } => { let mut inner = self.inner.lock().unwrap(); - inner.subscriptions.register_subscription(sub_id, on_applied, on_error); inner - .send_chan + .subscriptions + .register_legacy_subscription(sub_id, on_applied, on_error); + self.send_chan + .lock() + .unwrap() .as_mut() .ok_or(DisconnectedError {})? .unbounded_send(ws::ClientMessage::Subscribe(ws::Subscribe { @@ -247,6 +301,47 @@ impl DbContextImpl { })) .expect("Unable to send subscribe message: WS sender loop has dropped its recv channel"); } + // Subscribe: register the subscription in the [`SubscriptionManager`] + // and send the `Subscribe` WS message. + PendingMutation::SubscribeSingle { query_id, handle } => { + let mut inner = self.inner.lock().unwrap(); + // Register the subscription, so we can handle related messages from the server. + inner.subscriptions.register_subscription(query_id, handle.clone()); + if let Some(msg) = handle.start() { + self.send_chan + .lock() + .unwrap() + .as_mut() + .ok_or(DisconnectedError {})? + .unbounded_send(ws::ClientMessage::SubscribeSingle(msg)) + .expect("Unable to send subscribe message: WS sender loop has dropped its recv channel"); + } + // else, the handle was already cancelled. + } + + PendingMutation::Unsubscribe { query_id } => { + let mut inner = self.inner.lock().unwrap(); + match inner.subscriptions.handle_pending_unsubscribe(query_id) { + PendingUnsubscribeResult::DoNothing => + // The subscription was already unsubscribed, so we don't need to send an unsubscribe message. + { + return Ok(()) + } + + PendingUnsubscribeResult::RunCallback(callback) => { + callback(&self.make_event_ctx(Event::UnsubscribeApplied)); + } + PendingUnsubscribeResult::SendUnsubscribe(m) => { + self.send_chan + .lock() + .unwrap() + .as_mut() + .ok_or(DisconnectedError {})? + .unbounded_send(ws::ClientMessage::Unsubscribe(m)) + .expect("Unable to send unsubscribe message: WS sender loop has dropped its recv channel"); + } + } + } // CallReducer: send the `CallReducer` WS message. PendingMutation::CallReducer { reducer, args_bsatn } => { @@ -259,8 +354,9 @@ impl DbContextImpl { request_id: 0, flags, }); - inner - .send_chan + self.send_chan + .lock() + .unwrap() .as_mut() .ok_or(DisconnectedError {})? .unbounded_send(msg) @@ -273,7 +369,7 @@ impl DbContextImpl { // This will close the WebSocket loop in websocket.rs, // sending a close frame to the server, // eventually resulting in disconnect callbacks being called. - self.inner.lock().unwrap().send_chan = None; + *self.send_chan.lock().unwrap() = None; } // Callback stuff: these all do what you expect. @@ -486,7 +582,7 @@ impl DbContextImpl { /// Called by the autogenerated `DbConnection` method of the same name. pub fn is_active(&self) -> bool { - self.inner.lock().unwrap().send_chan.is_some() + self.send_chan.lock().unwrap().is_some() } /// Called by the autogenerated `DbConnection` method of the same name. @@ -589,9 +685,6 @@ pub(crate) struct DbContextImplInner { #[allow(unused)] runtime: Option, - /// None if we have disconnected. 
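Hoisting `send_chan` out of `DbContextImplInner` gives it its own lock, so `is_active` and the WS send paths no longer contend with (or deadlock against) code that holds `inner` while running callbacks. Reduced to its essentials (a simplified sketch; `Inner` and `Sender` are stand-ins, not real SDK types):

```rust
use std::sync::{Arc, Mutex};

struct Inner;  // stand-in for DbContextImplInner: callbacks, subscriptions, ...
struct Sender; // stand-in for the raw unbounded WebSocket channel

struct Ctx {
    inner: Arc<Mutex<Inner>>,
    /// `None` once disconnected, mirroring the new field in this diff.
    send_chan: Arc<Mutex<Option<Sender>>>,
}

impl Ctx {
    fn is_active(&self) -> bool {
        // Only the small channel lock is taken; `inner` may be held
        // elsewhere by a running callback without blocking this check.
        self.send_chan.lock().unwrap().is_some()
    }
}
```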
- send_chan: Option>>, - db_callbacks: DbCallbacks, reducer_callbacks: ReducerCallbacks, pub(crate) subscriptions: SubscriptionManager, @@ -746,7 +839,6 @@ but you must call one of them, or else the connection will never progress. let inner = Arc::new(StdMutex::new(DbContextImplInner { runtime, - send_chan: Some(raw_msg_send), db_callbacks, reducer_callbacks, subscriptions: SubscriptionManager::default(), @@ -760,11 +852,13 @@ but you must call one of them, or else the connection will never progress. let mut cache = ClientCache::default(); M::register_tables(&mut cache); let cache = Arc::new(StdMutex::new(cache)); + let send_chan = Arc::new(StdMutex::new(Some(raw_msg_send))); let (pending_mutations_send, pending_mutations_recv) = mpsc::unbounded(); let ctx_imp = DbContextImpl { runtime: handle, inner, + send_chan, cache, recv: Arc::new(TokioMutex::new(parsed_recv_chan)), pending_mutations_send, @@ -906,6 +1000,9 @@ enum ParsedMessage { InitialSubscription { db_update: M::DbUpdate, sub_id: u32 }, TransactionUpdate(Event, Option), IdentityToken(Identity, Box, Address), + SubscribeApplied { query_id: u32, initial_update: M::DbUpdate }, + UnsubscribeApplied { query_id: u32, initial_update: M::DbUpdate }, + SubscriptionError { query_id: Option, error: String }, Error(anyhow::Error), } @@ -976,9 +1073,34 @@ async fn parse_loop( ws::ServerMessage::OneOffQueryResponse(_) => { unreachable!("The Rust SDK does not implement one-off queries") } - ws::ServerMessage::SubscribeApplied(_) => todo!(), - ws::ServerMessage::UnsubscribeApplied(_) => todo!(), - ws::ServerMessage::SubscriptionError(_) => todo!(), + ws::ServerMessage::SubscribeApplied(subscribe_applied) => { + let table_rows = subscribe_applied.rows.table_rows; + let db_update = ws::DatabaseUpdate::from_iter(std::iter::once(table_rows)); + let query_id = subscribe_applied.query_id.id; + match M::DbUpdate::parse_update(db_update) { + Err(e) => ParsedMessage::Error(e.context("Failed to parse update from SubscribeApplied")), + Ok(initial_update) => ParsedMessage::SubscribeApplied { + query_id, + initial_update, + }, + } + } + ws::ServerMessage::UnsubscribeApplied(unsubscribe_applied) => { + let table_rows = unsubscribe_applied.rows.table_rows; + let db_update = ws::DatabaseUpdate::from_iter(std::iter::once(table_rows)); + let query_id = unsubscribe_applied.query_id.id; + match M::DbUpdate::parse_update(db_update) { + Err(e) => ParsedMessage::Error(e.context("Failed to parse update from UnsubscribeApplied")), + Ok(initial_update) => ParsedMessage::UnsubscribeApplied { + query_id, + initial_update, + }, + } + } + ws::ServerMessage::SubscriptionError(e) => ParsedMessage::SubscriptionError { + query_id: e.query_id, + error: e.error.to_string(), + }, }) .expect("Failed to send ParsedMessage to main thread"); } @@ -993,7 +1115,13 @@ pub(crate) enum PendingMutation { // TODO: replace `queries` with query_sql: String, sub_id: u32, }, - // TODO: Unsubscribe { ??? }, + Unsubscribe { + query_id: u32, + }, + SubscribeSingle { + query_id: u32, + handle: SubscriptionHandleImpl, + }, CallReducer { reducer: &'static str, args_bsatn: Vec, @@ -1061,3 +1189,25 @@ impl std::error::Error for DisconnectedError {} fn error_is_normal_disconnect(e: &anyhow::Error) -> bool { e.is::() } + +static NEXT_REQUEST_ID: AtomicU32 = AtomicU32::new(1); + +// Get the next request ID to use for a WebSocket message. 
+pub(crate) fn next_request_id() -> u32 {
+    NEXT_REQUEST_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+}
+
+static NEXT_SUBSCRIPTION_ID: AtomicU32 = AtomicU32::new(1);
+
+// Get the next subscription ID to use for a new subscription.
+pub(crate) fn next_subscription_id() -> u32 {
+    NEXT_SUBSCRIPTION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+}
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn dummy() {
+        assert_eq!(1, 1);
+    }
+}
diff --git a/crates/sdk/src/event.rs b/crates/sdk/src/event.rs
index a3bb18b1c69..075697fc2fc 100644
--- a/crates/sdk/src/event.rs
+++ b/crates/sdk/src/event.rs
@@ -36,6 +36,9 @@ pub enum Event {
     /// and to row delete callbacks resulting from the ended subscription.
     UnsubscribeApplied,
 
+    /// Event when a subscription was ended by a disconnection.
+    Disconnected,
+
     /// Event when an error causes one or more of our subscriptions to end prematurely,
     /// or to never be started.
     ///
diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs
index c531aa26473..461dba83bf4 100644
--- a/crates/sdk/src/lib.rs
+++ b/crates/sdk/src/lib.rs
@@ -27,6 +27,7 @@ pub use db_context::DbContext;
 pub use event::{Event, ReducerEvent, Status};
 pub use table::{Table, TableWithPrimaryKey};
 
+pub use spacetime_module::SubscriptionHandle;
 pub use spacetimedb_lib::{Address, Identity, ScheduleAt};
 pub use spacetimedb_sats::{i256, u256};
 
@@ -50,7 +51,7 @@ pub mod __codegen {
         parse_reducer_args, DbConnection, DbUpdate, EventContext, InModule, Reducer, SpacetimeModule,
         SubscriptionHandle, TableUpdate,
     };
-    pub use crate::subscription::{SubscriptionBuilder, SubscriptionHandleImpl};
+    pub use crate::subscription::{OnEndedCallback, SubscriptionBuilder, SubscriptionHandleImpl};
     pub use crate::{
         Address, DbConnectionBuilder, DbContext, DisconnectedError, Event, Identity, ReducerEvent, ScheduleAt, Table,
         TableWithPrimaryKey,
diff --git a/crates/sdk/src/spacetime_module.rs b/crates/sdk/src/spacetime_module.rs
index 1f7a1c69446..aaa59835ea8 100644
--- a/crates/sdk/src/spacetime_module.rs
+++ b/crates/sdk/src/spacetime_module.rs
@@ -3,8 +3,11 @@
 //! This module is internal, and may incompatibly change without warning.
 
 use crate::{
-    callbacks::DbCallbacks, client_cache::ClientCache, db_connection::DbContextImpl,
-    subscription::SubscriptionHandleImpl, Event,
+    callbacks::DbCallbacks,
+    client_cache::ClientCache,
+    db_connection::DbContextImpl,
+    subscription::{OnEndedCallback, SubscriptionHandleImpl},
+    Event,
 };
 use anyhow::Context;
 use bytes::Bytes;
@@ -109,11 +112,25 @@ where
     fn reducer_name(&self) -> &'static str;
 }
 
-pub trait SubscriptionHandle: InModule + Send + 'static
+pub trait SubscriptionHandle: InModule + Clone + Send + 'static
 where
     Self::Module: SpacetimeModule,
 {
     fn new(imp: SubscriptionHandleImpl) -> Self;
+    fn is_ended(&self) -> bool;
+
+    fn is_active(&self) -> bool;
+
+    /// Unsubscribe from the query controlled by this `SubscriptionHandle`,
+    /// then run `on_end` when its rows are removed from the client cache.
+    /// Returns an error if the subscription is already ended,
+    /// or if unsubscribe has already been called.
+    fn unsubscribe_then(self, on_end: OnEndedCallback) -> anyhow::Result<()>;
+
+    /// Unsubscribe from the query controlled by this `SubscriptionHandle`.
+    /// Returns an error if the subscription is already ended,
+    /// or if unsubscribe has already been called.
+    fn unsubscribe(self) -> anyhow::Result<()>;
 }
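The extended `SubscriptionHandle` trait above is exactly what the new integration tests exercise later in this patch. A usage sketch against hypothetical generated bindings (`module_bindings::{DbConnection, EventContext}`, the function name, and the query string are assumptions):

```rust
use module_bindings::{DbConnection, EventContext};
use spacetimedb_sdk::{DbContext, SubscriptionHandle};

fn watch_one_u8(ctx: &DbConnection) -> anyhow::Result<()> {
    let handle = ctx
        .subscription_builder()
        .on_applied(|_ctx: &EventContext| println!("rows are now in the client cache"))
        .on_error(|ctx: &EventContext| eprintln!("subscription ended early: {:?}", ctx.event))
        .subscribe("SELECT * FROM one_u8");

    // A fresh handle is neither applied nor ended yet.
    assert!(!handle.is_active() && !handle.is_ended());

    // Handles are now `Clone`, so one copy can observe while another unsubscribes.
    let watcher = handle.clone();
    handle.unsubscribe_then(Box::new(move |_ctx| {
        // Runs once the query's rows have been removed from the client cache.
        assert!(watcher.is_ended());
    }))
}
```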
 
 pub struct WithBsatn {
diff --git a/crates/sdk/src/subscription.rs b/crates/sdk/src/subscription.rs
index a34afdebc9c..ad017325985 100644
--- a/crates/sdk/src/subscription.rs
+++ b/crates/sdk/src/subscription.rs
@@ -3,39 +3,72 @@
 //! This module is internal, and may incompatibly change without warning.
 
 use crate::{
-    db_connection::{DbContextImpl, PendingMutation},
+    db_connection::{next_request_id, next_subscription_id, DbContextImpl, PendingMutation},
     spacetime_module::{SpacetimeModule, SubscriptionHandle},
 };
+use anyhow::bail;
+use futures_channel::mpsc;
+use spacetimedb_client_api_messages::websocket::{self as ws};
 use spacetimedb_data_structures::map::HashMap;
-use std::sync::atomic::AtomicU32;
+use std::sync::{atomic::AtomicU32, Arc, Mutex};
 
 // TODO: Rewrite for subscription manipulation, once we get that.
 // Currently race conditions abound, as you may resubscribe before the prev sub was applied,
 // clobbering your previous callback.
 pub struct SubscriptionManager {
-    subscriptions: HashMap>,
+    legacy_subscriptions: HashMap>,
+    new_subscriptions: HashMap>,
 }
 
 impl Default for SubscriptionManager {
     fn default() -> Self {
         Self {
-            subscriptions: HashMap::default(),
+            legacy_subscriptions: HashMap::default(),
+            new_subscriptions: HashMap::default(),
         }
     }
 }
 
 pub(crate) type OnAppliedCallback = Box::EventContext) + Send + 'static>;
 pub(crate) type OnErrorCallback = Box::EventContext) + Send + 'static>;
+pub type OnEndedCallback = Box::EventContext) + Send + 'static>;
+
+/// When handling a pending unsubscribe, there are three cases the caller must handle.
+pub(crate) enum PendingUnsubscribeResult {
+    // The unsubscribe message should be sent to the server.
+    SendUnsubscribe(ws::Unsubscribe),
+    // The subscription is immediately being cancelled, so the callback should be run.
+    RunCallback(OnEndedCallback),
+    // No action is required.
+    DoNothing,
+}
 
 impl SubscriptionManager {
-    pub(crate) fn register_subscription(
+    pub(crate) fn on_disconnect(&mut self, ctx: &M::EventContext) {
+        // We need to clear all the subscriptions,
+        // and run each one's `on_error` callback, since it is ending without an unsubscribe.
+        for (_, mut sub) in self.new_subscriptions.drain() {
+            if let Some(callback) = sub.on_error() {
+                callback(ctx);
+            }
+        }
+        for (_, mut s) in self.legacy_subscriptions.drain() {
+            if let Some(callback) = s.on_error.take() {
+                callback(ctx);
+            }
+        }
+    }
+
+    /// Register a new subscription. This does not send the subscription to the server.
+    /// Rather, it makes the subscription available for the next `apply_subscriptions` call.
+    pub(crate) fn register_legacy_subscription(
         &mut self,
         sub_id: u32,
         on_applied: Option>,
         on_error: Option>,
     ) {
-        self.subscriptions
+        self.legacy_subscriptions
             .try_insert(
                 sub_id,
                 SubscribedQuery {
@@ -46,13 +79,86 @@ impl SubscriptionManager {
             )
             .unwrap_or_else(|_| unreachable!("Duplicate subscription id {sub_id}"));
     }
-    pub(crate) fn subscription_applied(&mut self, ctx: &M::EventContext, sub_id: u32) {
-        let sub = self.subscriptions.get_mut(&sub_id).unwrap();
+
+    pub(crate) fn legacy_subscription_applied(&mut self, ctx: &M::EventContext, sub_id: u32) {
+        let sub = self.legacy_subscriptions.get_mut(&sub_id).unwrap();
         sub.is_applied = true;
         if let Some(callback) = sub.on_applied.take() {
             callback(ctx);
         }
     }
+
+    /// Register a new subscription, so that messages from the server about it
+    /// can be routed to its handle. This does not send the `Subscribe` message;
+    /// that happens when the pending `SubscribeSingle` mutation is applied.
+    pub(crate) fn register_subscription(&mut self, query_id: u32, handle: SubscriptionHandleImpl) {
+        self.new_subscriptions
+            .try_insert(query_id, handle.clone())
+            .unwrap_or_else(|_| unreachable!("Duplicate subscription id {query_id}"));
+    }
+
+    /// This should be called when we get a subscription applied message from the server.
+    pub(crate) fn subscription_applied(&mut self, ctx: &M::EventContext, sub_id: u32) {
+        let Some(sub) = self.new_subscriptions.get_mut(&sub_id) else {
+            // TODO: log or double check error handling.
+            return;
+        };
+        if let Some(callback) = sub.on_applied() {
+            callback(ctx)
+        }
+    }
+
+    /// This should be called when a pending unsubscribe mutation is applied,
+    /// to decide whether an `Unsubscribe` message still needs to be sent to the server.
+    pub(crate) fn handle_pending_unsubscribe(&mut self, sub_id: u32) -> PendingUnsubscribeResult {
+        let Some(sub) = self.new_subscriptions.get(&sub_id) else {
+            // TODO: log or double check error handling.
+            return PendingUnsubscribeResult::DoNothing;
+        };
+        let mut sub = sub.clone();
+        if sub.is_cancelled() {
+            // This means that the subscription was cancelled before it was started.
+            // We skip sending the subscription start message.
+            self.new_subscriptions.remove(&sub_id);
+            if let Some(callback) = sub.on_ended() {
+                return PendingUnsubscribeResult::RunCallback(callback);
+            } else {
+                return PendingUnsubscribeResult::DoNothing;
+            }
+        }
+        if sub.is_ended() {
+            // This should only happen if the subscription was ended due to an error.
+            // We don't need to send an unsubscribe message in this case.
+            self.new_subscriptions.remove(&sub_id);
+            return PendingUnsubscribeResult::DoNothing;
+        }
+        PendingUnsubscribeResult::SendUnsubscribe(ws::Unsubscribe {
+            query_id: ws::QueryId::new(sub_id),
+            request_id: next_request_id(),
+        })
+    }
+
+    /// This should be called when we get an unsubscribe applied message from the server.
+    pub(crate) fn unsubscribe_applied(&mut self, ctx: &M::EventContext, sub_id: u32) {
+        let Some(mut sub) = self.new_subscriptions.remove(&sub_id) else {
+            // TODO: double check error handling.
+            log::debug!("Unsubscribe applied called for missing query {:?}", sub_id);
+            return;
+        };
+        if let Some(callback) = sub.on_ended() {
+            callback(ctx)
+        }
+    }
+
+    /// This should be called when we get a subscription error message from the server.
+    pub(crate) fn subscription_error(&mut self, ctx: &M::EventContext, sub_id: u32) {
+        let Some(mut sub) = self.new_subscriptions.remove(&sub_id) else {
+            // TODO: double check error handling.
+            log::warn!("Subscription error received for missing query {:?}", sub_id);
+            return;
+        };
+        if let Some(callback) = sub.on_error() {
+            callback(ctx)
+        }
+    }
 }
 
 struct SubscribedQuery {
@@ -102,29 +208,23 @@ impl SubscriptionBuilder {
         self
     }
 
-    /// Subscribe to `queries`, which should be a collection of SQL queries,
-    /// each of which is a single-table non-projected `SELECT` statement
-    /// with an optional `WHERE` clause,
-    /// and `JOIN`ed with at most one other table as a filter.
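The branching in `handle_pending_unsubscribe` is easier to see in isolation. Below is a small self-contained model of the same three-way decision; the names are illustrative and none of this is SDK API:

```rust
#[derive(Clone, Copy, PartialEq, Eq)]
enum State {
    Pending, // not yet sent to the server
    Sent,    // `Subscribe` sent, not yet acknowledged
    Applied, // acknowledged; row updates are flowing
    Ended,   // unapplied, or ended by an error
}

enum Decision {
    RunEndedCallbackLocally, // cancelled before anything was sent
    DoNothing,               // already over; the server no longer tracks it
    SendUnsubscribe,         // tell the server to stop the query
}

fn decide(state: State, unsubscribe_called: bool) -> Decision {
    match state {
        // Cancelled before the `Subscribe` message went out: the server never
        // heard of this query, so only the local `on_ended` callback runs.
        State::Pending if unsubscribe_called => Decision::RunEndedCallbackLocally,
        // Already ended (e.g. by a server-side error): nothing left to do.
        State::Ended => Decision::DoNothing,
        // Otherwise the server believes the query is, or will become, live.
        _ => Decision::SendUnsubscribe,
    }
}

fn main() {
    // Cancel before anything was sent: only the local callback runs.
    assert!(matches!(decide(State::Pending, true), Decision::RunEndedCallbackLocally));
    // A live subscription needs a real `Unsubscribe` round-trip.
    assert!(matches!(decide(State::Applied, true), Decision::SendUnsubscribe));
}
```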
- pub fn subscribe(self, queries: impl IntoQueries) -> M::SubscriptionHandle { - static NEXT_SUB_ID: AtomicU32 = AtomicU32::new(0); - - let sub_id = NEXT_SUB_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - let Self { - on_applied, - on_error, - conn, - } = self; - conn.pending_mutations_send - .unbounded_send(PendingMutation::Subscribe { - on_applied, - on_error, - queries: queries.into_queries(), - sub_id, + pub fn subscribe(self, query_sql: &str) -> M::SubscriptionHandle { + let qid = next_subscription_id(); + let handle = SubscriptionHandleImpl::new(SubscriptionState::new( + qid, + query_sql.into(), + self.conn.pending_mutations_send.clone(), + self.on_applied, + self.on_error, + )); + self.conn + .pending_mutations_send + .unbounded_send(PendingMutation::SubscribeSingle { + query_id: qid, + handle: handle.clone(), }) .unwrap(); - M::SubscriptionHandle::new(SubscriptionHandleImpl { conn, sub_id }) + M::SubscriptionHandle::new(handle) } /// Subscribe to all rows from all tables. @@ -143,7 +243,24 @@ impl SubscriptionBuilder { /// or vice versa, may misbehave in any number of ways, /// including dropping subscriptions, corrupting the client cache, or panicking. pub fn subscribe_to_all_tables(self) { - self.subscribe(["SELECT * FROM *"]); + static NEXT_SUB_ID: AtomicU32 = AtomicU32::new(0); + + let sub_id = NEXT_SUB_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let query = "SELECT * FROM *"; + + let Self { + on_applied, + on_error, + conn, + } = self; + conn.pending_mutations_send + .unbounded_send(PendingMutation::Subscribe { + on_applied, + on_error, + queries: [query].into_queries(), + sub_id, + }) + .unwrap(); } } @@ -171,55 +288,212 @@ impl>, const N: usize> IntoQueries for [S; N] { } } +/// This tracks what messages have been exchanged with the server. +#[derive(Debug, PartialEq, Eq, Clone)] +enum SubscriptionServerState { + Pending, // This hasn't been sent to the server yet. + Sent, // We have sent it to the server. + Applied, // The server has acknowledged it, and we are receiving updates. + Ended, // It has been unapplied. + Error, // There was an error that ended the subscription. +} + +/// We track the state of a subscription here. +/// A reference to this is held by the `SubscriptionHandle` that clients use to unsubscribe, +/// and by the `SubscriptionManager` that handles updates from the server. +pub(crate) struct SubscriptionState { + query_id: u32, + query_sql: Box, + unsubscribe_called: bool, + status: SubscriptionServerState, + on_applied: Option>, + on_error: Option>, + on_ended: Option>, + // This is needed to schedule client operations. + // Note that we shouldn't have a full connection here. + pending_mutation_sender: mpsc::UnboundedSender>, +} + +impl SubscriptionState { + pub(crate) fn new( + query_id: u32, + query_sql: Box, + pending_mutation_sender: mpsc::UnboundedSender>, + on_applied: Option>, + on_error: Option>, + ) -> Self { + Self { + query_id, + query_sql, + unsubscribe_called: false, + status: SubscriptionServerState::Pending, + on_applied, + on_error, + on_ended: None, + pending_mutation_sender, + } + } + + /// Start the subscription. + /// This updates the state in the handle, and returns the message to be sent to the server. + /// The caller is responsible for sending the message to the server. + pub(crate) fn start(&mut self) -> Option { + if self.unsubscribe_called { + // This means that the subscription was cancelled before it was started. + // We skip sending the subscription start message. 
+ return None; + } + if self.status != SubscriptionServerState::Pending { + // This should never happen. + // We should only start a subscription once. + // If we are starting it again, we have a bug. + unreachable!("Subscription already started"); + } + self.status = SubscriptionServerState::Sent; + Some(ws::SubscribeSingle { + query_id: ws::QueryId::new(self.query_id), + query: self.query_sql.clone(), + request_id: next_request_id(), + }) + } + + pub fn unsubscribe_then(&mut self, on_end: Option>) -> anyhow::Result<()> { + // pub fn unsubscribe_then(&mut self, on_end: impl FnOnce(&M::EventContext) + Send + 'static) -> anyhow::Result<()> { + if self.is_ended() { + bail!("Subscription has already ended"); + } + // Check if it has already been called. + if self.unsubscribe_called { + bail!("Unsubscribe already called"); + } + + self.unsubscribe_called = true; + self.on_ended = on_end; + // self.on_ended = Some(Box::new(on_end)); + + // We send this even if the status is still Pending, so we can remove it from the manager. + self.pending_mutation_sender + .unbounded_send(PendingMutation::Unsubscribe { + query_id: self.query_id, + }) + .unwrap(); + Ok(()) + } + + /// Check if the client ended the subscription before we sent anything to the server. + pub fn is_cancelled(&self) -> bool { + self.status == SubscriptionServerState::Pending && self.unsubscribe_called + } + + pub fn is_ended(&self) -> bool { + matches!( + self.status, + SubscriptionServerState::Ended | SubscriptionServerState::Error + ) + } + + pub fn is_active(&self) -> bool { + match self.status { + SubscriptionServerState::Applied => !self.unsubscribe_called, + _ => false, + } + } + + pub fn on_applied(&mut self) -> Option> { + if self.status != SubscriptionServerState::Sent { + // Potentially log a warning. This might make sense if we are shutting down. + log::debug!( + "on_applied called for query {:?} with status: {:?}", + self.query_id, + self.status + ); + return None; + } + log::debug!("on_applied called for query {:?}", self.query_id); + self.status = SubscriptionServerState::Applied; + self.on_applied.take() + } + + pub fn on_ended(&mut self) -> Option> { + // TODO: Consider logging a warning if the state is wrong (like being in the Error state). + if self.is_ended() { + return None; + } + self.status = SubscriptionServerState::Ended; + self.on_ended.take() + } + + pub fn on_error(&mut self) -> Option> { + // TODO: Consider logging a warning if the state is wrong. + if self.is_ended() { + return None; + } + self.status = SubscriptionServerState::Error; + self.on_error.take() + } +} + #[doc(hidden)] /// Internal implementation held by the module-specific generated `SubscriptionHandle` type. pub struct SubscriptionHandleImpl { - conn: DbContextImpl, - #[allow(unused)] - sub_id: u32, + pub(crate) inner: Arc>>, +} + +impl Clone for SubscriptionHandleImpl { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } } impl SubscriptionHandleImpl { - /// Called by the `SubscriptionHandle` method of the same name. + pub(crate) fn new(inner: SubscriptionState) -> Self { + Self { + inner: Arc::new(Mutex::new(inner)), + } + } + + pub(crate) fn start(&self) -> Option { + let mut inner = self.inner.lock().unwrap(); + inner.start() + } + + pub(crate) fn is_cancelled(&self) -> bool { + self.inner.lock().unwrap().is_cancelled() + } pub fn is_ended(&self) -> bool { - // When a subscription ends, we remove its `SubscribedQuery` from the `SubscriptionManager`. 
- // So, to check if a subscription has ended, we check if the entry is present. - // TODO: Note that we never end a subscription currently. - // This will change with the implementation of the subscription management proposal. - !self - .conn - .inner - .lock() - .unwrap() - .subscriptions - .subscriptions - .contains_key(&self.sub_id) + self.inner.lock().unwrap().is_ended() } - /// Called by the `SubscriptionHandle` method of the same name. pub fn is_active(&self) -> bool { - // A subscription is active if: - // - It has not yet ended, i.e. is still present in the `SubscriptionManager`. - // - It has been applied. - self.conn - .inner - .lock() - .unwrap() - .subscriptions - .subscriptions - .get(&self.sub_id) - .map(|sub| sub.is_applied) - .unwrap_or(false) + self.inner.lock().unwrap().is_active() } /// Called by the `SubscriptionHandle` method of the same name. - pub fn unsubscribe(self) -> anyhow::Result<()> { - self.unsubscribe_then(|_| {}) + pub fn unsubscribe_then(self, on_end: Option>) -> anyhow::Result<()> { + let mut inner = self.inner.lock().unwrap(); + inner.unsubscribe_then(on_end) } - /// Called by the `SubscriptionHandle` method of the same name. - // TODO: requires the new subscription interface and WS protocol. - pub fn unsubscribe_then(self, _on_end: impl FnOnce(&M::EventContext) + Send + 'static) -> anyhow::Result<()> { - todo!() + /// Record that the subscription has been applied and return the callback to run. + /// The caller is responsible for calling the callback. + pub(crate) fn on_applied(&mut self) -> Option> { + let mut inner = self.inner.lock().unwrap(); + inner.on_applied() + } + + /// Record that the subscription has been applied and return the callback to run. + /// The caller is responsible for calling the callback. + pub(crate) fn on_ended(&mut self) -> Option> { + let mut inner = self.inner.lock().unwrap(); + inner.on_ended() + } + + /// Record that the subscription has errored and return the callback to run. + /// The caller is responsible for calling the callback. 
+ pub(crate) fn on_error(&mut self) -> Option> { + let mut inner = self.inner.lock().unwrap(); + inner.on_error() } } diff --git a/crates/sdk/tests/connect_disconnect_client/src/main.rs b/crates/sdk/tests/connect_disconnect_client/src/main.rs index 1370c21f00a..6cf85ba793e 100644 --- a/crates/sdk/tests/connect_disconnect_client/src/main.rs +++ b/crates/sdk/tests/connect_disconnect_client/src/main.rs @@ -2,7 +2,7 @@ mod module_bindings; use module_bindings::*; -use spacetimedb_sdk::{DbContext, Table}; +use spacetimedb_sdk::{DbContext, Event, Table}; use test_counter::TestCounter; @@ -15,6 +15,7 @@ fn db_name_or_panic() -> String { fn main() { let disconnect_test_counter = TestCounter::new(); let disconnect_result = disconnect_test_counter.add_test("disconnect"); + let on_error_result = disconnect_test_counter.add_test("on_error"); let connect_test_counter = TestCounter::new(); let connected_result = connect_test_counter.add_test("on_connect"); @@ -27,7 +28,12 @@ fn main() { .on_connect(move |ctx, _, _| { connected_result(Ok(())); ctx.subscription_builder() - .on_error(|ctx| panic!("Subscription failed: {:?}", ctx.event)) + .on_error(|ctx| { + if !matches!(ctx.event, Event::Disconnected) { + panic!("Subscription failed: {:?}", ctx.event) + } + on_error_result(Ok(())); + }) .on_applied(move |ctx| { let check = || { anyhow::ensure!(ctx.db.connected().count() == 1); @@ -40,7 +46,7 @@ fn main() { }; sub_applied_one_row_result(check()); }) - .subscribe(["SELECT * FROM connected;"]); + .subscribe("SELECT * FROM connected"); }) .on_disconnect(move |ctx, err| { assert!( @@ -94,7 +100,7 @@ fn main() { sub_applied_one_row_result(check()); }) .on_error(|ctx| panic!("subscription on_error: {:?}", ctx.event)) - .subscribe(["SELECT * FROM disconnected;"]); + .subscribe("SELECT * FROM disconnected"); new_connection.run_threaded(); diff --git a/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_connected_reducer.rs b/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_connected_reducer.rs index 13002531ccb..cbde34263f6 100644 --- a/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_connected_reducer.rs +++ b/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_connected_reducer.rs @@ -24,17 +24,17 @@ impl __sdk::InModule for IdentityConnectedArgs { pub struct IdentityConnectedCallbackId(__sdk::CallbackId); #[allow(non_camel_case_types)] -/// Extension trait for access to the reducer `__identity_connected__`. +/// Extension trait for access to the reducer `identity_connected`. /// /// Implemented for [`super::RemoteReducers`]. pub trait identity_connected { - /// Request that the remote module invoke the reducer `__identity_connected__` to run as soon as possible. + /// Request that the remote module invoke the reducer `identity_connected` to run as soon as possible. /// /// This method returns immediately, and errors only if we are unable to send the request. /// The reducer will run asynchronously in the future, /// and its status can be observed by listening for [`Self::on_identity_connected`] callbacks. fn identity_connected(&self) -> __anyhow::Result<()>; - /// Register a callback to run whenever we are notified of an invocation of the reducer `__identity_connected__`. + /// Register a callback to run whenever we are notified of an invocation of the reducer `identity_connected`. 
/// /// The [`super::EventContext`] passed to the `callback` /// will always have [`__sdk::Event::Reducer`] as its `event`, @@ -55,15 +55,14 @@ pub trait identity_connected { impl identity_connected for super::RemoteReducers { fn identity_connected(&self) -> __anyhow::Result<()> { - self.imp - .call_reducer("__identity_connected__", IdentityConnectedArgs {}) + self.imp.call_reducer("identity_connected", IdentityConnectedArgs {}) } fn on_identity_connected( &self, mut callback: impl FnMut(&super::EventContext) + Send + 'static, ) -> IdentityConnectedCallbackId { IdentityConnectedCallbackId(self.imp.on_reducer( - "__identity_connected__", + "identity_connected", Box::new(move |ctx: &super::EventContext| { let super::EventContext { event: @@ -81,19 +80,19 @@ impl identity_connected for super::RemoteReducers { )) } fn remove_on_identity_connected(&self, callback: IdentityConnectedCallbackId) { - self.imp.remove_on_reducer("__identity_connected__", callback.0) + self.imp.remove_on_reducer("identity_connected", callback.0) } } #[allow(non_camel_case_types)] #[doc(hidden)] -/// Extension trait for setting the call-flags for the reducer `__identity_connected__`. +/// Extension trait for setting the call-flags for the reducer `identity_connected`. /// /// Implemented for [`super::SetReducerFlags`]. /// /// This type is currently unstable and may be removed without a major version bump. pub trait set_flags_for_identity_connected { - /// Set the call-reducer flags for the reducer `__identity_connected__` to `flags`. + /// Set the call-reducer flags for the reducer `identity_connected` to `flags`. /// /// This type is currently unstable and may be removed without a major version bump. fn identity_connected(&self, flags: __ws::CallReducerFlags); @@ -101,6 +100,6 @@ pub trait set_flags_for_identity_connected { impl set_flags_for_identity_connected for super::SetReducerFlags { fn identity_connected(&self, flags: __ws::CallReducerFlags) { - self.imp.set_call_reducer_flags("__identity_connected__", flags); + self.imp.set_call_reducer_flags("identity_connected", flags); } } diff --git a/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_disconnected_reducer.rs b/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_disconnected_reducer.rs index 41dc34ff8e3..2b740d11f6e 100644 --- a/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_disconnected_reducer.rs +++ b/crates/sdk/tests/connect_disconnect_client/src/module_bindings/identity_disconnected_reducer.rs @@ -24,17 +24,17 @@ impl __sdk::InModule for IdentityDisconnectedArgs { pub struct IdentityDisconnectedCallbackId(__sdk::CallbackId); #[allow(non_camel_case_types)] -/// Extension trait for access to the reducer `__identity_disconnected__`. +/// Extension trait for access to the reducer `identity_disconnected`. /// /// Implemented for [`super::RemoteReducers`]. pub trait identity_disconnected { - /// Request that the remote module invoke the reducer `__identity_disconnected__` to run as soon as possible. + /// Request that the remote module invoke the reducer `identity_disconnected` to run as soon as possible. /// /// This method returns immediately, and errors only if we are unable to send the request. /// The reducer will run asynchronously in the future, /// and its status can be observed by listening for [`Self::on_identity_disconnected`] callbacks. 
fn identity_disconnected(&self) -> __anyhow::Result<()>; - /// Register a callback to run whenever we are notified of an invocation of the reducer `__identity_disconnected__`. + /// Register a callback to run whenever we are notified of an invocation of the reducer `identity_disconnected`. /// /// The [`super::EventContext`] passed to the `callback` /// will always have [`__sdk::Event::Reducer`] as its `event`, @@ -56,14 +56,14 @@ pub trait identity_disconnected { impl identity_disconnected for super::RemoteReducers { fn identity_disconnected(&self) -> __anyhow::Result<()> { self.imp - .call_reducer("__identity_disconnected__", IdentityDisconnectedArgs {}) + .call_reducer("identity_disconnected", IdentityDisconnectedArgs {}) } fn on_identity_disconnected( &self, mut callback: impl FnMut(&super::EventContext) + Send + 'static, ) -> IdentityDisconnectedCallbackId { IdentityDisconnectedCallbackId(self.imp.on_reducer( - "__identity_disconnected__", + "identity_disconnected", Box::new(move |ctx: &super::EventContext| { let super::EventContext { event: @@ -81,19 +81,19 @@ impl identity_disconnected for super::RemoteReducers { )) } fn remove_on_identity_disconnected(&self, callback: IdentityDisconnectedCallbackId) { - self.imp.remove_on_reducer("__identity_disconnected__", callback.0) + self.imp.remove_on_reducer("identity_disconnected", callback.0) } } #[allow(non_camel_case_types)] #[doc(hidden)] -/// Extension trait for setting the call-flags for the reducer `__identity_disconnected__`. +/// Extension trait for setting the call-flags for the reducer `identity_disconnected`. /// /// Implemented for [`super::SetReducerFlags`]. /// /// This type is currently unstable and may be removed without a major version bump. pub trait set_flags_for_identity_disconnected { - /// Set the call-reducer flags for the reducer `__identity_disconnected__` to `flags`. + /// Set the call-reducer flags for the reducer `identity_disconnected` to `flags`. /// /// This type is currently unstable and may be removed without a major version bump. fn identity_disconnected(&self, flags: __ws::CallReducerFlags); @@ -101,6 +101,6 @@ pub trait set_flags_for_identity_disconnected { impl set_flags_for_identity_disconnected for super::SetReducerFlags { fn identity_disconnected(&self, flags: __ws::CallReducerFlags) { - self.imp.set_call_reducer_flags("__identity_disconnected__", flags); + self.imp.set_call_reducer_flags("identity_disconnected", flags); } } diff --git a/crates/sdk/tests/connect_disconnect_client/src/module_bindings/mod.rs b/crates/sdk/tests/connect_disconnect_client/src/module_bindings/mod.rs index 94806d96908..c808e0590de 100644 --- a/crates/sdk/tests/connect_disconnect_client/src/module_bindings/mod.rs +++ b/crates/sdk/tests/connect_disconnect_client/src/module_bindings/mod.rs @@ -44,8 +44,8 @@ impl __sdk::InModule for Reducer { impl __sdk::Reducer for Reducer { fn reducer_name(&self) -> &'static str { match self { - Reducer::IdentityConnected => "__identity_connected__", - Reducer::IdentityDisconnected => "__identity_disconnected__", + Reducer::IdentityConnected => "identity_connected", + Reducer::IdentityDisconnected => "identity_disconnected", } } } @@ -53,13 +53,16 @@ impl TryFrom<__ws::ReducerCallInfo<__ws::BsatnFormat>> for Reducer { type Error = __anyhow::Error; fn try_from(value: __ws::ReducerCallInfo<__ws::BsatnFormat>) -> __anyhow::Result { match &value.reducer_name[..] 
{ - "__identity_connected__" => Ok(__sdk::parse_reducer_args::< - identity_connected_reducer::IdentityConnectedArgs, - >("__identity_connected__", &value.args)? - .into()), - "__identity_disconnected__" => Ok(__sdk::parse_reducer_args::< + "identity_connected" => Ok( + __sdk::parse_reducer_args::( + "identity_connected", + &value.args, + )? + .into(), + ), + "identity_disconnected" => Ok(__sdk::parse_reducer_args::< identity_disconnected_reducer::IdentityDisconnectedArgs, - >("__identity_disconnected__", &value.args)? + >("identity_disconnected", &value.args)? .into()), _ => Err(__anyhow::anyhow!("Unknown reducer {:?}", value.reducer_name)), } @@ -375,6 +378,7 @@ impl __sdk::EventContext for EventContext { /// A handle on a subscribed query. // TODO: Document this better after implementing the new subscription API. +#[derive(Clone)] pub struct SubscriptionHandle { imp: __sdk::SubscriptionHandleImpl, } @@ -387,6 +391,26 @@ impl __sdk::SubscriptionHandle for SubscriptionHandle { fn new(imp: __sdk::SubscriptionHandleImpl) -> Self { Self { imp } } + + /// Returns true if this subscription has been terminated due to an unsubscribe call or an error. + fn is_ended(&self) -> bool { + self.imp.is_ended() + } + + /// Returns true if this subscription has been applied and has not yet been unsubscribed. + fn is_active(&self) -> bool { + self.imp.is_active() + } + + /// Unsubscribe from the query controlled by this `SubscriptionHandle`, + /// then run `on_end` when its rows are removed from the client cache. + fn unsubscribe_then(self, on_end: __sdk::OnEndedCallback) -> __anyhow::Result<()> { + self.imp.unsubscribe_then(Some(on_end)) + } + + fn unsubscribe(self) -> __anyhow::Result<()> { + self.imp.unsubscribe_then(None) + } } /// Alias trait for a [`__sdk::DbContext`] connected to this module, diff --git a/crates/sdk/tests/test-client/src/main.rs b/crates/sdk/tests/test-client/src/main.rs index a34c05cd650..5da3353678f 100644 --- a/crates/sdk/tests/test-client/src/main.rs +++ b/crates/sdk/tests/test-client/src/main.rs @@ -3,12 +3,13 @@ mod module_bindings; use core::fmt::Display; +use std::sync::{atomic::AtomicUsize, Arc, Mutex}; use module_bindings::*; use spacetimedb_sdk::{ credentials, i256, u256, unstable::CallReducerFlags, Address, DbConnectionBuilder, DbContext, Event, Identity, - ReducerEvent, Status, Table, + ReducerEvent, Status, SubscriptionHandle, Table, }; use test_counter::TestCounter; @@ -53,6 +54,9 @@ fn main() { match &*test { "insert_primitive" => exec_insert_primitive(), + "subscribe_and_cancel" => exec_subscribe_and_cancel(), + "subscribe_and_unsubscribe" => exec_subscribe_and_unsubscribe(), + "subscription_error_smoke_test" => exec_subscription_error_smoke_test(), "delete_primitive" => exec_delete_primitive(), "update_primitive" => exec_update_primitive(), @@ -355,10 +359,106 @@ fn connect(test_counter: &std::sync::Arc) -> DbConnection { } fn subscribe_all_then(ctx: &impl RemoteDbContext, callback: impl FnOnce(&EventContext) + Send + 'static) { - ctx.subscription_builder() - .on_applied(callback) - .on_error(|ctx| panic!("Subscription errored: {:?}", ctx.event)) - .subscribe(SUBSCRIBE_ALL); + let remaining_queries = Arc::new(AtomicUsize::new(SUBSCRIBE_ALL.len())); + let callback = Arc::new(Mutex::new(Some(callback))); + for query in SUBSCRIBE_ALL { + let atomic = remaining_queries.clone(); + let callback = callback.clone(); + + let on_applied = move |ctx: &EventContext| { + let count = atomic.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); + if count == 1 { + // Only execute 
callback when the last subscription completes + if let Some(cb) = callback.lock().unwrap().take() { + cb(ctx); + } + } + }; + + ctx.subscription_builder() + .on_applied(on_applied) + .on_error(|ctx| panic!("Subscription errored: {:?}", ctx.event)) + .subscribe(query); + } +} + +fn exec_subscribe_and_cancel() { + let test_counter = TestCounter::new(); + let cb = test_counter.add_test("unsubscribe_then_called"); + connect_then(&test_counter, { + move |ctx| { + let handle = ctx + .subscription_builder() + .on_applied(move |_ctx: &EventContext| { + panic!("Subscription should never be applied"); + }) + .on_error(|ctx| panic!("Subscription errored: {:?}", ctx.event)) + .subscribe("SELECT * FROM one_u8;"); + assert!(!handle.is_active()); + assert!(!handle.is_ended()); + let handle_clone = handle.clone(); + handle + .unsubscribe_then(Box::new(move |_| { + assert!(!handle_clone.is_active()); + assert!(handle_clone.is_ended()); + cb(Ok(())); + })) + .unwrap(); + } + }); + test_counter.wait_for_all(); +} + +fn exec_subscribe_and_unsubscribe() { + let test_counter = TestCounter::new(); + let cb = test_counter.add_test("unsubscribe_then_called"); + connect_then(&test_counter, { + move |ctx| { + let handle_cell: Arc>> = Arc::new(Mutex::new(None)); + let hc_clone = handle_cell.clone(); + let handle = ctx + .subscription_builder() + .on_applied(move |ctx: &EventContext| { + let handle = { hc_clone.lock().unwrap().as_ref().unwrap().clone() }; + assert!(ctx.is_active()); + assert!(handle.is_active()); + assert!(!handle.is_ended()); + let handle_clone = handle.clone(); + handle + .unsubscribe_then(Box::new(move |_| { + assert!(!handle_clone.is_active()); + assert!(handle_clone.is_ended()); + cb(Ok(())); + })) + .unwrap(); + }) + .on_error(|ctx| panic!("Subscription errored: {:?}", ctx.event)) + .subscribe("SELECT * FROM one_u8;"); + handle_cell.lock().unwrap().replace(handle.clone()); + assert!(!handle.is_active()); + assert!(!handle.is_ended()); + } + }); + test_counter.wait_for_all(); +} + +fn exec_subscription_error_smoke_test() { + let test_counter = TestCounter::new(); + let cb = test_counter.add_test("error_callback_is_called"); + connect_then(&test_counter, { + move |ctx| { + let handle = ctx + .subscription_builder() + .on_applied(move |_ctx: &EventContext| { + panic!("Subscription should never be applied"); + }) + .on_error(|_| cb(Ok(()))) + .subscribe("SELEcCT * FROM one_u8;"); // intentional typo + assert!(!handle.is_active()); + assert!(!handle.is_ended()); + } + }); + test_counter.wait_for_all(); } /// This tests that we can: @@ -1671,30 +1771,27 @@ fn exec_caller_alice_receives_reducer_callback_but_not_bob() { let sub_applied = pre_ins_counter.add_test(format!("sub_applied_{who}")); let counter2 = counter.clone(); - conn.subscription_builder() - .on_applied(move |ctx| { - sub_applied(Ok(())); - - // Test that we are notified when a row is inserted. 
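`subscribe_all_then` above swaps one multi-query subscription for N single-query ones and fires its callback only after the last `on_applied`. The synchronization it relies on is a countdown latch; a standalone sketch of that pattern, with illustrative names:

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc, Mutex,
};

// Run a one-shot callback only after `n` independent completions.
fn countdown(n: usize, f: impl FnOnce() + Send + 'static) -> impl Fn() + Clone + Send {
    let remaining = Arc::new(AtomicUsize::new(n));
    let f = Arc::new(Mutex::new(Some(f)));
    move || {
        // `fetch_sub` returns the previous value, so 1 means "this was the last one".
        if remaining.fetch_sub(1, Ordering::SeqCst) == 1 {
            // The `Mutex<Option<_>>` lets a `FnOnce` run from a `Fn` closure exactly once.
            if let Some(f) = f.lock().unwrap().take() {
                f();
            }
        }
    }
}

fn main() {
    let applied = countdown(3, || println!("all subscriptions applied"));
    for _ in 0..3 {
        applied(); // in the SDK this runs inside each `on_applied` callback
    }
}
```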
- let db = ctx.db(); - let mut one_u8_inserted = Some(counter2.add_test(format!("one_u8_inserted_{who}"))); - db.one_u_8().on_insert(move |_, row| { - (one_u8_inserted.take().unwrap())(check_val(row.n, 42)); - }); - let mut one_u16_inserted = Some(counter2.add_test(format!("one_u16_inserted_{who}"))); - db.one_u_16().on_insert(move |event, row| { - let run_checks = || { - anyhow::ensure!( - matches!(event.event, Event::UnknownTransaction), - "reducer should be unknown", - ); - check_val(row.n, 24) - }; - (one_u16_inserted.take().unwrap())(run_checks()); - }); - }) - .on_error(|_| panic!("Subscription error")) - .subscribe(["SELECT * FROM one_u8", "SELECT * FROM one_u16"]); + subscribe_all_then(&conn, move |ctx| { + sub_applied(Ok(())); + + // Test that we are notified when a row is inserted. + let db = ctx.db(); + let mut one_u8_inserted = Some(counter2.add_test(format!("one_u8_inserted_{who}"))); + db.one_u_8().on_insert(move |_, row| { + (one_u8_inserted.take().unwrap())(check_val(row.n, 42)); + }); + let mut one_u16_inserted = Some(counter2.add_test(format!("one_u16_inserted_{who}"))); + db.one_u_16().on_insert(move |event, row| { + let run_checks = || { + anyhow::ensure!( + matches!(event.event, Event::UnknownTransaction), + "reducer should be unknown", + ); + check_val(row.n, 24) + }; + (one_u16_inserted.take().unwrap())(run_checks()); + }); + }); conn }); diff --git a/crates/sdk/tests/test-client/src/module_bindings/mod.rs b/crates/sdk/tests/test-client/src/module_bindings/mod.rs index 2b3883fca11..70b9321516b 100644 --- a/crates/sdk/tests/test-client/src/module_bindings/mod.rs +++ b/crates/sdk/tests/test-client/src/module_bindings/mod.rs @@ -3284,6 +3284,7 @@ impl __sdk::EventContext for EventContext { /// A handle on a subscribed query. // TODO: Document this better after implementing the new subscription API. +#[derive(Clone)] pub struct SubscriptionHandle { imp: __sdk::SubscriptionHandleImpl, } @@ -3296,6 +3297,26 @@ impl __sdk::SubscriptionHandle for SubscriptionHandle { fn new(imp: __sdk::SubscriptionHandleImpl) -> Self { Self { imp } } + + /// Returns true if this subscription has been terminated due to an unsubscribe call or an error. + fn is_ended(&self) -> bool { + self.imp.is_ended() + } + + /// Returns true if this subscription has been applied and has not yet been unsubscribed. + fn is_active(&self) -> bool { + self.imp.is_active() + } + + /// Unsubscribe from the query controlled by this `SubscriptionHandle`, + /// then run `on_end` when its rows are removed from the client cache. + fn unsubscribe_then(self, on_end: __sdk::OnEndedCallback) -> __anyhow::Result<()> { + self.imp.unsubscribe_then(Some(on_end)) + } + + fn unsubscribe(self) -> __anyhow::Result<()> { + self.imp.unsubscribe_then(None) + } } /// Alias trait for a [`__sdk::DbContext`] connected to this module, diff --git a/crates/sdk/tests/test.rs b/crates/sdk/tests/test.rs index 64bcc8d1b1a..65278b75ab2 100644 --- a/crates/sdk/tests/test.rs +++ b/crates/sdk/tests/test.rs @@ -23,6 +23,20 @@ macro_rules! 
declare_tests_with_suffix { make_test("insert_primitive").run(); } + #[test] + fn subscribe_and_cancel() { + make_test("subscribe_and_cancel").run(); + } + + #[test] + fn subscribe_and_unsubscribe() { + make_test("subscribe_and_unsubscribe").run(); + } + + #[test] + fn subscription_error_smoke_test() { + make_test("subscription_error_smoke_test").run(); + } #[test] fn delete_primitive() { make_test("delete_primitive").run(); diff --git a/crates/standalone/Dockerfile b/crates/standalone/Dockerfile index 5457d838eae..713848409e9 100644 --- a/crates/standalone/Dockerfile +++ b/crates/standalone/Dockerfile @@ -32,7 +32,7 @@ RUN mkdir -p /stdb/data && ln -s /usr/src/app/crates/standalone/config.toml /std ENV PATH="/usr/src/app/target/debug:${PATH}" FROM debian as env-release -COPY --from=builder /usr/src/app/target/release/spacetimedb /usr/local/bin/ +COPY --from=builder /usr/src/app/target/release/spacetimedb-standalone /usr/local/bin/ COPY --from=builder /usr/src/app/crates/standalone/config.toml /stdb/data/config.toml FROM env-${CARGO_PROFILE} @@ -40,5 +40,5 @@ FROM env-${CARGO_PROFILE} EXPOSE 3000 ENV RUST_BACKTRACE=1 -ENTRYPOINT ["spacetimedb", "--data-dir=/stdb/data", "--jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub", "--jwt-priv-key-path=/etc/spacetimedb/id_ecdsa"] +ENTRYPOINT ["spacetimedb-standalone", "--data-dir=/stdb/data", "--jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub", "--jwt-priv-key-path=/etc/spacetimedb/id_ecdsa"] CMD ["start"] diff --git a/crates/table/benches/page_manager.rs b/crates/table/benches/page_manager.rs index eaa16b507a7..8ea1bd0e114 100644 --- a/crates/table/benches/page_manager.rs +++ b/crates/table/benches/page_manager.rs @@ -728,7 +728,8 @@ fn make_table_with_index(unique: bool) -> (Table, IndexId) { let cols = R::indexed_columns(); let index_id = IndexId::SENTINEL; let idx = tbl.new_index(cols, unique).unwrap(); - tbl.insert_index(&NullBlobStore, index_id, idx); + // SAFETY: index was derived from the table. + unsafe { tbl.insert_index(&NullBlobStore, index_id, idx) }; (tbl, index_id) } diff --git a/crates/table/src/blob_store.rs b/crates/table/src/blob_store.rs index 3431b5b35cf..52078072ba6 100644 --- a/crates/table/src/blob_store.rs +++ b/crates/table/src/blob_store.rs @@ -104,6 +104,27 @@ pub trait BlobStore: Sync { /// /// Used when capturing a snapshot. fn iter_blobs(&self) -> BlobsIter<'_>; + + /// Returns the amount of memory in bytes used by blobs in this `BlobStore`. + /// + /// Duplicate blobs are counted a number of times equal to their refcount. + /// This is in order to preserve the property that inserting a large blob + /// causes this quantity to increase by that blob's size, + /// and deleting a large blob causes it to decrease the same amount. + fn bytes_used_by_blobs(&self) -> u64 { + self.iter_blobs() + .map(|(_, uses, data)| data.len() as u64 * uses as u64) + .sum() + } + + /// Returns the number of blobs, or more precisely, blob-usages, recorded in this `BlobStore`. + /// + /// Duplicate blobs are counted a number of times equal to their refcount. + /// This is in order to preserve the property that inserting a large blob + /// causes this quantity to increase by 1, and deleting a large blob causes it to decrease by 1. + fn num_blobs(&self) -> u64 { + self.iter_blobs().map(|(_, uses, _)| uses as u64).sum() + } } /// A blob store that panics on all operations. 
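The two default methods added to `BlobStore` weight every blob by its refcount, so inserting or deleting a large blob always moves the totals by that blob's full size, even when storage is deduplicated. A toy store (not the real trait) demonstrating the invariant:

```rust
use std::collections::HashMap;

// Toy stand-in for a blob store: hash -> (refcount, data).
struct ToyBlobStore {
    blobs: HashMap<u64, (u64, Vec<u8>)>,
}

impl ToyBlobStore {
    fn bytes_used_by_blobs(&self) -> u64 {
        // Count each blob once per use, mirroring the refcount-weighted sum above.
        self.blobs.values().map(|(uses, data)| data.len() as u64 * uses).sum()
    }

    fn num_blobs(&self) -> u64 {
        self.blobs.values().map(|(uses, _)| uses).sum()
    }
}

fn main() {
    let mut store = ToyBlobStore { blobs: HashMap::new() };
    // The same 100-byte blob inserted twice: deduplicated storage, refcount 2.
    store.blobs.insert(0xdead, (2, vec![0u8; 100]));
    // Both counters behave as if the blob were stored twice.
    assert_eq!(store.bytes_used_by_blobs(), 200);
    assert_eq!(store.num_blobs(), 2);
}
```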
diff --git a/crates/table/src/btree_index.rs b/crates/table/src/btree_index.rs
index 34c939ee849..aa36749287f 100644
--- a/crates/table/src/btree_index.rs
+++ b/crates/table/src/btree_index.rs
@@ -30,9 +30,12 @@ use spacetimedb_sats::{
     algebraic_value::Packed, i256, product_value::InvalidFieldError, u256, AlgebraicType, AlgebraicValue, ProductType,
 };
 
+mod key_size;
 mod multimap;
 mod uniquemap;
 
+pub use key_size::KeySize;
+
 type Index = multimap::MultiMap;
 type IndexIter<'a, K> = multimap::MultiMapRangeIter<'a, K, RowPointer>;
 type UniqueIndex = uniquemap::UniqueMap;
@@ -315,35 +318,71 @@ impl TypedIndex {
     /// Add the row referred to by `row_ref` to the index `self`,
     /// which must be keyed at `cols`.
     ///
-    /// If `cols` is inconsistent with `self`,
-    /// or the `row_ref` has a row type other than that used for `self`,
-    /// this will behave oddly; it may return an error,
-    /// or may insert a nonsense value into the index.
-    /// Note, however, that it will not invoke undefined behavior.
+    /// The returned `usize` is the number of bytes used by the key.
+    /// [`BTreeIndex::check_and_insert`] will use this
+    /// to update the counter for [`BTreeIndex::num_key_bytes`].
+    /// We want to store said counter outside of the [`TypedIndex`] enum,
+    /// but we can only compute the size using type info within the [`TypedIndex`],
+    /// so we have to return the size across this boundary.
     ///
-    /// Returns `Ok(Some(existing_row))` if this index was a unique index that was violated.
+    /// Returns `Err(existing_row)` if this index was a unique index that was violated.
     /// The index is not inserted to in that case.
-    fn insert(&mut self, cols: &ColList, row_ref: RowRef<'_>) -> Result, InvalidFieldError> {
-        fn mm_insert_at_type(
+    ///
+    /// # Safety
+    ///
+    /// 1. Caller promises that `cols` matches what was given at construction (`Self::new`).
+    /// 2. Caller promises that the projection of `row_ref`'s type equals the index's key type.
+    unsafe fn insert(&mut self, cols: &ColList, row_ref: RowRef<'_>) -> Result {
+        fn project_to_singleton_key(cols: &ColList, row_ref: RowRef<'_>) -> T {
+            // Extract the column.
+            let col_pos = cols.as_singleton();
+            // SAFETY: Caller promised that `cols` matches what was given at construction (`Self::new`).
+            // In the case of `.clone_structure()`, the structure is preserved,
+            // so the promise is also preserved.
+            // This entails, because we reached here, that `cols` is a singleton.
+            let col_pos = unsafe { col_pos.unwrap_unchecked() }.idx();
+
+            // Extract the layout of the column.
+            let col_layouts = &row_ref.row_layout().product().elements;
+            // SAFETY:
+            // - Caller promised that projecting the `row_ref`'s type/layout to `self.indexed_columns`
+            //   gives us the index's key type.
+            //   This entails that each `ColId` in `self.indexed_columns`
+            //   must be in-bounds of `row_ref`'s layout.
+            let col_layout = unsafe { col_layouts.get_unchecked(col_pos) };
+
+            // Read the column in `row_ref`.
+            // SAFETY:
+            // - `col_layout` was just derived from the row layout.
+            // - Caller promised that the type-projection of the row type/layout
+            //   equals the index's key type.
+            //   We've reached here, so the index's key type is compatible with `T`.
+            // - `self` is a valid row so offsetting to `col_layout` is valid.
+ unsafe { T::unchecked_read_column(row_ref, col_layout) } + } + + fn mm_insert_at_type( this: &mut Index, cols: &ColList, row_ref: RowRef<'_>, - ) -> Result, InvalidFieldError> { - let col_pos = cols.as_singleton().unwrap(); - let key = row_ref.read_col(col_pos).map_err(|_| col_pos)?; + ) -> Result { + let key: T = project_to_singleton_key(cols, row_ref); + let key_size = key.key_size_in_bytes(); this.insert(key, row_ref.pointer()); - Ok(None) + Ok(key_size) } - fn um_insert_at_type( + fn um_insert_at_type( this: &mut UniqueIndex, cols: &ColList, row_ref: RowRef<'_>, - ) -> Result, InvalidFieldError> { - let col_pos = cols.as_singleton().unwrap(); - let key = row_ref.read_col(col_pos).map_err(|_| col_pos)?; - Ok(this.insert(key, row_ref.pointer()).copied()) + ) -> Result { + let key: T = project_to_singleton_key(cols, row_ref); + let key_size = key.key_size_in_bytes(); + this.insert(key, row_ref.pointer()) + .map_err(|ptr| *ptr) + .map(|_| key_size) } - let unique_violation = match self { + match self { Self::Bool(idx) => mm_insert_at_type(idx, cols, row_ref), Self::U8(idx) => mm_insert_at_type(idx, cols, row_ref), Self::I8(idx) => mm_insert_at_type(idx, cols, row_ref), @@ -359,9 +398,11 @@ impl TypedIndex { Self::I256(idx) => mm_insert_at_type(idx, cols, row_ref), Self::String(idx) => mm_insert_at_type(idx, cols, row_ref), Self::AV(this) => { - let key = row_ref.project(cols)?; + // SAFETY: Caller promised that any `col` in `cols` is in-bounds of `row_ref`'s layout. + let key = unsafe { row_ref.project_unchecked(cols) }; + let key_size = key.key_size_in_bytes(); this.insert(key, row_ref.pointer()); - Ok(None) + Ok(key_size) } Self::UniqueBool(idx) => um_insert_at_type(idx, cols, row_ref), Self::UniqueU8(idx) => um_insert_at_type(idx, cols, row_ref), @@ -378,11 +419,14 @@ impl TypedIndex { Self::UniqueI256(idx) => um_insert_at_type(idx, cols, row_ref), Self::UniqueString(idx) => um_insert_at_type(idx, cols, row_ref), Self::UniqueAV(this) => { - let key = row_ref.project(cols)?; - Ok(this.insert(key, row_ref.pointer()).copied()) + // SAFETY: Caller promised that any `col` in `cols` is in-bounds of `row_ref`'s layout. + let key = unsafe { row_ref.project_unchecked(cols) }; + let key_size = key.key_size_in_bytes(); + this.insert(key, row_ref.pointer()) + .map_err(|ptr| *ptr) + .map(|_| key_size) } - }?; - Ok(unique_violation) + } } /// Remove the row referred to by `row_ref` from the index `self`, @@ -393,24 +437,34 @@ impl TypedIndex { /// this will behave oddly; it may return an error, do nothing, /// or remove the wrong value from the index. /// Note, however, that it will not invoke undefined behavior. - fn delete(&mut self, cols: &ColList, row_ref: RowRef<'_>) -> Result { - fn mm_delete_at_type( + /// + /// If the row was present and has been deleted, returns `Ok(Some(key_size_in_bytes))`, + /// where `key_size_in_bytes` is the size of the key. + /// [`BTreeIndex::delete`] will use this + /// to update the counter for [`BTreeIndex::num_key_bytes`]. + /// We want to store said counter outside of the [`TypedIndex`] enum, + /// but we can only compute the size using type info within the [`TypedIndex`], + /// so we have to return the size across this boundary. 
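The doc comments above and below describe the same shape: only the typed layer can measure a key, so the byte count is computed there and handed back across the enum boundary for the wrapper to fold into its memoized counters. A minimal standalone sketch of that pattern, with illustrative names rather than the real `BTreeIndex`:

```rust
use std::collections::BTreeMap;

// Simplified key; the real code reads typed column values straight out of the row.
enum Key {
    U32(u32),
    Str(String),
}

// Only this enum knows the concrete key type, so only it can measure keys.
enum TypedMap {
    U32(BTreeMap<u32, u64>),
    Str(BTreeMap<String, u64>),
}

impl TypedMap {
    // Insert, and report the key's size in bytes back across the boundary.
    fn insert(&mut self, key: Key, row: u64) -> usize {
        match (self, key) {
            (TypedMap::U32(m), Key::U32(k)) => {
                m.insert(k, row);
                std::mem::size_of::<u32>()
            }
            (TypedMap::Str(m), Key::Str(k)) => {
                let size = k.len(); // "data size": live bytes, no length prefix
                m.insert(k, row);
                size
            }
            _ => panic!("key does not match the index's key type"),
        }
    }
}

// The untyped wrapper owns the memoized counters, like `BTreeIndex` above.
struct Index {
    map: TypedMap,
    num_rows: u64,
    num_key_bytes: u64,
}

impl Index {
    fn insert(&mut self, key: Key, row: u64) {
        // The size computed inside the typed layer updates the outer counters;
        // deletion would symmetrically subtract it, keeping both queries O(1).
        let size = self.map.insert(key, row);
        self.num_rows += 1;
        self.num_key_bytes += size as u64;
    }
}

fn main() {
    let mut idx = Index { map: TypedMap::Str(BTreeMap::new()), num_rows: 0, num_key_bytes: 0 };
    idx.insert(Key::Str("alice".into()), 1);
    assert_eq!((idx.num_rows, idx.num_key_bytes), (1, 5));
}
```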
+ fn delete(&mut self, cols: &ColList, row_ref: RowRef<'_>) -> Result, InvalidFieldError> { + fn mm_delete_at_type( this: &mut Index, cols: &ColList, row_ref: RowRef<'_>, - ) -> Result { + ) -> Result, InvalidFieldError> { let col_pos = cols.as_singleton().unwrap(); - let key = row_ref.read_col(col_pos).map_err(|_| col_pos)?; - Ok(this.delete(&key, &row_ref.pointer())) + let key: T = row_ref.read_col(col_pos).map_err(|_| col_pos)?; + let key_size = key.key_size_in_bytes(); + Ok(this.delete(&key, &row_ref.pointer()).then_some(key_size)) } - fn um_delete_at_type( + fn um_delete_at_type( this: &mut UniqueIndex, cols: &ColList, row_ref: RowRef<'_>, - ) -> Result { + ) -> Result, InvalidFieldError> { let col_pos = cols.as_singleton().unwrap(); - let key = row_ref.read_col(col_pos).map_err(|_| col_pos)?; - Ok(this.delete(&key)) + let key: T = row_ref.read_col(col_pos).map_err(|_| col_pos)?; + let key_size = key.key_size_in_bytes(); + Ok(this.delete(&key).then_some(key_size)) } match self { @@ -430,7 +484,8 @@ impl TypedIndex { Self::String(this) => mm_delete_at_type(this, cols, row_ref), Self::AV(this) => { let key = row_ref.project(cols)?; - Ok(this.delete(&key, &row_ref.pointer())) + let key_size = key.key_size_in_bytes(); + Ok(this.delete(&key, &row_ref.pointer()).then_some(key_size)) } Self::UniqueBool(this) => um_delete_at_type(this, cols, row_ref), Self::UniqueU8(this) => um_delete_at_type(this, cols, row_ref), @@ -448,7 +503,8 @@ impl TypedIndex { Self::UniqueString(this) => um_delete_at_type(this, cols, row_ref), Self::UniqueAV(this) => { let key = row_ref.project(cols)?; - Ok(this.delete(&key)) + let key_size = key.key_size_in_bytes(); + Ok(this.delete(&key).then_some(key_size)) } } } @@ -633,7 +689,20 @@ pub struct BTreeIndex { idx: TypedIndex, /// The key type of this index. /// This is the projection of the row type to the types of the columns indexed. + // TODO(perf, bikeshedding): Could trim `sizeof(BTreeIndex)` to 64 if this was `Box`. pub key_type: AlgebraicType, + + /// The number of rows in this index. + /// + /// Memoized counter for [`Self::num_rows`]. + num_rows: u64, + + /// The number of key bytes in this index. + /// + /// Memoized counter for [`Self::num_key_bytes`]. + /// See that method for more detailed documentation. + num_key_bytes: u64, + /// Given a full row, typed at some `ty: ProductType`, /// these columns are the ones that this index indexes. /// Projecting the `ty` to `self.indexed_columns` yields the index's type `self.key_type`. @@ -645,13 +714,19 @@ impl MemoryUsage for BTreeIndex { let Self { idx, key_type, + num_rows, + num_key_bytes, indexed_columns, } = self; - idx.heap_usage() + key_type.heap_usage() + indexed_columns.heap_usage() + idx.heap_usage() + + key_type.heap_usage() + + num_rows.heap_usage() + + num_key_bytes.heap_usage() + + indexed_columns.heap_usage() } } -static_assert_size!(BTreeIndex, 64); +static_assert_size!(BTreeIndex, 80); impl BTreeIndex { /// Returns a new possibly unique index, with `index_id` for a set of columns. @@ -661,6 +736,8 @@ impl BTreeIndex { Ok(Self { idx: typed_index, key_type, + num_rows: 0, + num_key_bytes: 0, indexed_columns, }) } @@ -674,6 +751,8 @@ impl BTreeIndex { Self { idx, key_type, + num_rows: 0, + num_key_bytes: 0, indexed_columns, } } @@ -686,16 +765,46 @@ impl BTreeIndex { /// Inserts `ptr` with the value `row` to this index. /// This index will extract the necessary values from `row` based on `self.indexed_columns`. 
    ///
-    /// Returns `Ok(Some(existing_row))` if this insertion would violate a unique constraint.
-    pub fn check_and_insert(&mut self, row_ref: RowRef<'_>) -> Result, InvalidFieldError> {
-        self.idx.insert(&self.indexed_columns, row_ref)
+    /// Returns `Err(existing_row)` if this insertion would violate a unique constraint.
+    ///
+    /// # Safety
+    ///
+    /// Caller promises that projecting the `row_ref`'s type
+    /// to the index's columns equals the index's key type.
+    /// This is entailed by an index belonging to the table's schema.
+    /// It also follows from `row_ref`'s type/layout
+    /// being the same as passed in on `self`'s construction.
+    pub unsafe fn check_and_insert(&mut self, row_ref: RowRef<'_>) -> Result<(), RowPointer> {
+        // SAFETY:
+        // 1. We're passing the same `ColList` that was provided during construction.
+        // 2. Forward the caller's proof obligation.
+        let res = unsafe { self.idx.insert(&self.indexed_columns, row_ref) };
+        match res {
+            Ok(key_size) => {
+                // No existing row; the new row was inserted.
+                // Update the `num_rows` and `num_key_bytes` counters
+                // to account for the new insertion.
+                self.num_rows += 1;
+                self.num_key_bytes += key_size as u64;
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
     }
 
     /// Deletes `row_ref` with its indexed value `row_ref.project(&self.indexed_columns)` from this index.
     ///
     /// Returns whether `row_ref` was present.
     pub fn delete(&mut self, row_ref: RowRef<'_>) -> Result {
-        self.idx.delete(&self.indexed_columns, row_ref)
+        if let Some(size_in_bytes) = self.idx.delete(&self.indexed_columns, row_ref)? {
+            // Was present, and deleted: update the `num_rows` and `num_key_bytes` counters.
+            self.num_rows -= 1;
+            self.num_key_bytes -= size_in_bytes as u64;
+            Ok(true)
+        } else {
+            // Was not present: don't update counters.
+            Ok(false)
+        }
     }
 
     /// Returns whether `value` is in this index.
@@ -724,16 +833,21 @@ impl BTreeIndex {
     /// Extends [`BTreeIndex`] with `rows`.
     ///
     /// Returns the first unique constraint violation caused when adding this index, if any.
-    pub fn build_from_rows<'table>(
+    ///
+    /// # Safety
+    ///
+    /// Caller promises that projecting each `row_ref`'s type
+    /// to the index's columns equals the index's key type.
+    /// This is entailed by an index belonging to the table's schema.
+    /// It also follows from `row_ref`'s type/layout
+    /// being the same as passed in on `self`'s construction.
+    pub unsafe fn build_from_rows<'table>(
         &mut self,
         rows: impl IntoIterator>,
-    ) -> Result, InvalidFieldError> {
-        for row_ref in rows {
-            if let violation @ Some(_) = self.check_and_insert(row_ref)? {
-                return Ok(violation);
-            }
-        }
-        Ok(None)
+    ) -> Result<(), RowPointer> {
+        rows.into_iter()
+            // SAFETY: Forward caller proof obligation.
+            .try_for_each(|row_ref| unsafe { self.check_and_insert(row_ref) })
     }
 
     /// Deletes all entries from the index, leaving it empty.
@@ -743,12 +857,35 @@
     /// rather than constructing a new `BTreeIndex`.
     pub fn clear(&mut self) {
         self.idx.clear();
+        self.num_key_bytes = 0;
+        self.num_rows = 0;
     }
 
     /// The number of unique keys in this index.
     pub fn num_keys(&self) -> usize {
         self.idx.num_keys()
     }
+
+    /// The number of rows stored in this index.
+    ///
+    /// Note that, for non-unique indexes, this may be larger than [`Self::num_keys`].
+    ///
+    /// This method runs in constant time.
+    pub fn num_rows(&self) -> u64 {
+        self.num_rows
+    }
+
+    /// The number of bytes stored in keys in this index.
+ /// + /// For non-unique indexes, duplicate keys are counted once for each row that refers to them, + /// even though the internal storage may deduplicate them as an optimization. + /// + /// This method runs in constant time. + /// + /// See the [`KeySize`] trait for more details on how this method computes its result. + pub fn num_key_bytes(&self) -> u64 { + self.num_key_bytes + } } #[cfg(test)] @@ -828,7 +965,7 @@ mod test { prop_assert_eq!(index.idx.len(), 0); prop_assert_eq!(index.contains_any(&value), false); - prop_assert_eq!(index.check_and_insert(row_ref).unwrap(), None); + prop_assert_eq!(unsafe { index.check_and_insert(row_ref) }, Ok(())); prop_assert_eq!(index.idx.len(), 1); prop_assert_eq!(index.contains_any(&value), true); @@ -854,7 +991,8 @@ mod test { ); // Insert. - prop_assert_eq!(index.check_and_insert(row_ref).unwrap(), None); + // SAFETY: `row_ref` has the same type as was passed in when constructing `index`. + prop_assert_eq!(unsafe { index.check_and_insert(row_ref) }, Ok(())); // Inserting again would be a problem. prop_assert_eq!(index.idx.len(), 1); @@ -863,7 +1001,8 @@ mod test { get_rows_that_violate_unique_constraint(&index, &value).unwrap().collect::>(), [row_ref.pointer()] ); - prop_assert_eq!(index.check_and_insert(row_ref).unwrap(), Some(row_ref.pointer())); + // SAFETY: `row_ref` has the same type as was passed in when constructing `index`. + prop_assert_eq!(unsafe { index.check_and_insert(row_ref) }, Err(row_ref.pointer())); } #[test] @@ -887,7 +1026,8 @@ mod test { let row = product![x]; let row_ref = table.insert(&mut blob_store, &row).unwrap().1; val_to_ptr.insert(x, row_ref.pointer()); - prop_assert_eq!(index.check_and_insert(row_ref).unwrap(), None); + // SAFETY: `row_ref` has the same type as was passed in when constructing `index`. + prop_assert_eq!(unsafe { index.check_and_insert(row_ref) }, Ok(())); } fn test_seek(index: &BTreeIndex, val_to_ptr: &HashMap, range: impl RangeBounds, expect: impl IntoIterator) -> TestCaseResult { diff --git a/crates/table/src/btree_index/key_size.rs b/crates/table/src/btree_index/key_size.rs new file mode 100644 index 00000000000..2128341f8e6 --- /dev/null +++ b/crates/table/src/btree_index/key_size.rs @@ -0,0 +1,143 @@ +use spacetimedb_sats::{ + algebraic_value::Packed, i256, u256, AlgebraicValue, ArrayValue, ProductValue, SumValue, F32, F64, +}; + +/// Index keys whose memory usage we can measure and report. +/// +/// The reported memory usage of an index is based on: +/// +/// - the number of entries in that index, i.e. the number of `RowPointer`s it stores, +/// - the total size of the keys for every entry in that index. +/// +/// This trait is used to measure the latter. +/// The metric we measure, sometimes called "data size," +/// is the number of live user-supplied bytes in the key. +/// This excludes padding and lengths, though it does include sum tags. +/// +/// The key size of a value is defined depending on that value's type: +/// - Integer, float and boolean values take bytes according to their [`std::mem::size_of`]. +/// - Strings take bytes equal to their length in bytes. +/// No overhead is counted, unlike in the BFLATN or BSATN size. +/// - Sum values take 1 byte for the tag, plus the bytes of their active payload. +/// Inactive variants and padding are not counted, unlike in the BFLATN size. +/// - Product values take bytes equal to the sum of their elements' bytes. +/// Padding is not counted, unlike in the BFLATN size. +/// - Array values take bytes equal to the sum of their elements' bytes. 
diff --git a/crates/table/src/btree_index/key_size.rs b/crates/table/src/btree_index/key_size.rs
new file mode 100644
index 00000000000..2128341f8e6
--- /dev/null
+++ b/crates/table/src/btree_index/key_size.rs
@@ -0,0 +1,143 @@
+use spacetimedb_sats::{
+    algebraic_value::Packed, i256, u256, AlgebraicValue, ArrayValue, ProductValue, SumValue, F32, F64,
+};
+
+/// Index keys whose memory usage we can measure and report.
+///
+/// The reported memory usage of an index is based on:
+///
+/// - the number of entries in that index, i.e. the number of `RowPointer`s it stores,
+/// - the total size of the keys for every entry in that index.
+///
+/// This trait is used to measure the latter.
+/// The metric we measure, sometimes called "data size,"
+/// is the number of live user-supplied bytes in the key.
+/// This excludes padding and lengths, though it does include sum tags.
+///
+/// The key size of a value is defined depending on that value's type:
+/// - Integer, float and boolean values take bytes according to their [`std::mem::size_of`].
+/// - Strings take bytes equal to their length in bytes.
+///   No overhead is counted, unlike in the BFLATN or BSATN size.
+/// - Sum values take 1 byte for the tag, plus the bytes of their active payload.
+///   Inactive variants and padding are not counted, unlike in the BFLATN size.
+/// - Product values take bytes equal to the sum of their elements' bytes.
+///   Padding is not counted, unlike in the BFLATN size.
+/// - Array values take bytes equal to the sum of their elements' bytes.
+///   As with strings, no overhead is counted.
+pub trait KeySize {
+    fn key_size_in_bytes(&self) -> usize;
+}
+
+macro_rules! impl_key_size_primitive {
+    ($prim:ty) => {
+        impl KeySize for $prim {
+            fn key_size_in_bytes(&self) -> usize {
+                std::mem::size_of::<$prim>()
+            }
+        }
+    };
+    ($($prim:ty,)*) => {
+        $(impl_key_size_primitive!($prim);)*
+    };
+}
+
+impl_key_size_primitive!(
+    bool,
+    u8,
+    i8,
+    u16,
+    i16,
+    u32,
+    i32,
+    u64,
+    i64,
+    u128,
+    i128,
+    Packed<u128>,
+    Packed<i128>,
+    u256,
+    i256,
+    F32,
+    F64,
+);
+
+impl KeySize for Box<str> {
+    fn key_size_in_bytes(&self) -> usize {
+        self.len()
+    }
+}
+
+impl KeySize for AlgebraicValue {
+    fn key_size_in_bytes(&self) -> usize {
+        match self {
+            AlgebraicValue::Bool(x) => x.key_size_in_bytes(),
+            AlgebraicValue::U8(x) => x.key_size_in_bytes(),
+            AlgebraicValue::I8(x) => x.key_size_in_bytes(),
+            AlgebraicValue::U16(x) => x.key_size_in_bytes(),
+            AlgebraicValue::I16(x) => x.key_size_in_bytes(),
+            AlgebraicValue::U32(x) => x.key_size_in_bytes(),
+            AlgebraicValue::I32(x) => x.key_size_in_bytes(),
+            AlgebraicValue::U64(x) => x.key_size_in_bytes(),
+            AlgebraicValue::I64(x) => x.key_size_in_bytes(),
+            AlgebraicValue::U128(x) => x.key_size_in_bytes(),
+            AlgebraicValue::I128(x) => x.key_size_in_bytes(),
+            AlgebraicValue::U256(x) => x.key_size_in_bytes(),
+            AlgebraicValue::I256(x) => x.key_size_in_bytes(),
+            AlgebraicValue::F32(x) => x.key_size_in_bytes(),
+            AlgebraicValue::F64(x) => x.key_size_in_bytes(),
+            AlgebraicValue::String(x) => x.key_size_in_bytes(),
+            AlgebraicValue::Sum(x) => x.key_size_in_bytes(),
+            AlgebraicValue::Product(x) => x.key_size_in_bytes(),
+            AlgebraicValue::Array(x) => x.key_size_in_bytes(),
+
+            AlgebraicValue::Min | AlgebraicValue::Max => unreachable!(),
+        }
+    }
+}
+
+impl KeySize for SumValue {
+    fn key_size_in_bytes(&self) -> usize {
+        1 + self.value.key_size_in_bytes()
+    }
+}
+
+impl KeySize for ProductValue {
+    fn key_size_in_bytes(&self) -> usize {
+        self.elements.key_size_in_bytes()
+    }
+}
+
+impl<K> KeySize for [K]
+where
+    K: KeySize,
+{
+    // TODO(perf, bikeshedding): check that this is optimized to `size_of::<K>() * self.len()`
+    // when `K` is a primitive.
+    fn key_size_in_bytes(&self) -> usize {
+        self.iter().map(|elt| elt.key_size_in_bytes()).sum()
+    }
+}
+
+impl KeySize for ArrayValue {
+    fn key_size_in_bytes(&self) -> usize {
+        match self {
+            ArrayValue::Sum(elts) => elts.key_size_in_bytes(),
+            ArrayValue::Product(elts) => elts.key_size_in_bytes(),
+            ArrayValue::Bool(elts) => elts.key_size_in_bytes(),
+            ArrayValue::I8(elts) => elts.key_size_in_bytes(),
+            ArrayValue::U8(elts) => elts.key_size_in_bytes(),
+            ArrayValue::I16(elts) => elts.key_size_in_bytes(),
+            ArrayValue::U16(elts) => elts.key_size_in_bytes(),
+            ArrayValue::I32(elts) => elts.key_size_in_bytes(),
+            ArrayValue::U32(elts) => elts.key_size_in_bytes(),
+            ArrayValue::I64(elts) => elts.key_size_in_bytes(),
+            ArrayValue::U64(elts) => elts.key_size_in_bytes(),
+            ArrayValue::I128(elts) => elts.key_size_in_bytes(),
+            ArrayValue::U128(elts) => elts.key_size_in_bytes(),
+            ArrayValue::I256(elts) => elts.key_size_in_bytes(),
+            ArrayValue::U256(elts) => elts.key_size_in_bytes(),
+            ArrayValue::F32(elts) => elts.key_size_in_bytes(),
+            ArrayValue::F64(elts) => elts.key_size_in_bytes(),
+            ArrayValue::String(elts) => elts.key_size_in_bytes(),
+            ArrayValue::Array(elts) => elts.key_size_in_bytes(),
+        }
+    }
+}
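To make the "data size" metric concrete, a short worked example of the rules above. It is illustrative only: the import path for `KeySize` is an assumption, and the values are arbitrary.

```rust
use spacetimedb_sats::AlgebraicValue;
// Assumed import path for the trait defined in the file above.
use spacetimedb_table::btree_index::KeySize;

fn main() {
    // An integer counts its `size_of`: a u32 is 4 bytes.
    let n = AlgebraicValue::U32(7);
    assert_eq!(n.key_size_in_bytes(), 4);

    // A string counts only its UTF-8 length: 5 bytes, no length prefix.
    let s = AlgebraicValue::String("hello".into());
    assert_eq!(s.key_size_in_bytes(), 5);

    // A product sums its elements with no padding: 4 + 5 = 9 bytes.
    let key = AlgebraicValue::product(vec![n, s]);
    assert_eq!(key.key_size_in_bytes(), 9);
}
```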
diff --git a/crates/table/src/btree_index/uniquemap.rs b/crates/table/src/btree_index/uniquemap.rs
index 01de3758bf1..be3f0e0abd7 100644
--- a/crates/table/src/btree_index/uniquemap.rs
+++ b/crates/table/src/btree_index/uniquemap.rs
@@ -29,13 +29,13 @@ impl<K: Ord, V> UniqueMap<K, V> {
     ///
     /// If `key` was already present in the map, does not add an association with `val`.
-    /// Returns the existing associated value instead.
-    pub fn insert(&mut self, key: K, val: V) -> Option<&V> {
+    /// Returns `Err` with the existing associated value instead.
+    pub fn insert(&mut self, key: K, val: V) -> Result<(), &V> {
         match self.map.entry(key) {
             Entry::Vacant(e) => {
                 e.insert(val);
-                None
+                Ok(())
             }
-            Entry::Occupied(e) => Some(e.into_mut()),
+            Entry::Occupied(e) => Err(e.into_mut()),
         }
     }

diff --git a/crates/table/src/page.rs b/crates/table/src/page.rs
index ba4535db0b1..a734199feb3 100644
--- a/crates/table/src/page.rs
+++ b/crates/table/src/page.rs
@@ -248,6 +248,12 @@ struct VarHeader {
     /// pre-decrement this index.
     // TODO(perf,future-work): determine how to "lower" the high water mark when freeing the "top"-most granule.
     first: PageOffset,
+
+    /// The number of granules currently used by rows within this page.
+    ///
+    /// [`Page::bytes_used_by_rows`] needs this information.
+    /// Stored here because otherwise counting it would require traversing all the present rows.
+    num_granules: u16,
 }

 impl MemoryUsage for VarHeader {
@@ -256,12 +262,13 @@ impl MemoryUsage for VarHeader {
             next_free,
             freelist_len,
             first,
+            num_granules,
         } = self;
-        next_free.heap_usage() + freelist_len.heap_usage() + first.heap_usage()
+        next_free.heap_usage() + freelist_len.heap_usage() + first.heap_usage() + num_granules.heap_usage()
     }
 }

-static_assert_size!(VarHeader, 6);
+static_assert_size!(VarHeader, 8);

 impl Default for VarHeader {
     fn default() -> Self {
@@ -269,6 +276,7 @@ impl Default for VarHeader {
             next_free: FreeCellRef::NIL,
             freelist_len: 0,
             first: PageOffset::PAGE_END,
+            num_granules: 0,
         }
     }
 }
@@ -771,6 +779,8 @@ impl<'page> VarView<'page> {
             granule,
         );

+        self.header.num_granules += 1;
+
         Ok(granule)
     }

@@ -812,6 +822,7 @@ impl<'page> VarView<'page> {
         // but we want to return a whole "run" of sequential freed chunks,
         // which requires some bookkeeping (or an O(> n) linked list traversal).
         self.header.freelist_len += 1;
+        self.header.num_granules -= 1;

         let adjuster = self.adjuster();
         // SAFETY: Per caller contract, `offset` is a valid `VarLenGranule`,
@@ -1113,10 +1124,94 @@ impl Page {
     }

     /// Returns the number of rows stored in this page.
+    ///
+    /// This method runs in constant time.
     pub fn num_rows(&self) -> usize {
         self.header.fixed.num_rows as usize
     }

+    #[cfg(test)]
+    /// Use this page's present rows bitvec to compute the number of present rows.
+    ///
+    /// This can be compared with [`Self::num_rows`] as a consistency check during tests.
+    pub fn reconstruct_num_rows(&self) -> usize {
+        // If we cared, we could rewrite this to `u64::count_ones` on each block of the bitset.
+        // We do not care. This method is slow.
+        self.header.fixed.present_rows.iter_set().count()
+    }
+
+    /// Returns the number of var-len granules allocated in this page.
+    ///
+    /// This method runs in constant time.
+    pub fn num_var_len_granules(&self) -> usize {
+        self.header.var.num_granules as usize
+    }
+
+    #[cfg(test)]
+    /// # Safety
+    ///
+    /// - `var_len_visitor` must be a valid [`VarLenMembers`] visitor
+    ///   specialized to the type and layout of rows within this [`Page`].
+    /// - `fixed_row_size` must be exactly the length in bytes of fixed rows in this page,
+    ///   which must further be the length of rows expected by the `var_len_visitor`.
+    pub unsafe fn reconstruct_num_var_len_granules(
+        &self,
+        fixed_row_size: Size,
+        var_len_visitor: &impl VarLenMembers,
+    ) -> usize {
+        self.iter_fixed_len(fixed_row_size)
+            .flat_map(|row| unsafe {
+                // Safety: `row` came out of `iter_fixed_len`,
+                // which, due to caller requirements on `fixed_row_size`,
+                // is giving us valid, aligned, initialized rows of the row type.
+                var_len_visitor.visit_var_len(self.get_row_data(row, fixed_row_size))
+            })
+            .flat_map(|var_len_obj| unsafe {
+                // Safety: We believe `row` to be valid
+                // and `var_len_visitor` to be correctly visiting its var-len members.
+                // Therefore, `var_len_obj` is a valid var-len object.
+                self.iter_var_len_object(var_len_obj.first_granule)
+            })
+            .count()
+    }
+
+    /// Returns the number of bytes used by rows stored in this page.
+    ///
+    /// This is necessarily an overestimate of live data bytes, as it includes:
+    /// - Padding bytes within the fixed-length portion of the rows.
+    /// - [`VarLenRef`] pointer-like portions of rows.
+    /// - Unused trailing parts of partially-filled [`VarLenGranule`]s.
+    /// - [`VarLenGranule`]s used to store [`BlobHash`]es.
+    ///
+    /// Note that large blobs themselves are not counted.
+    /// The caller should obtain a count of the bytes used by large blobs
+    /// from the [`super::blob_store::BlobStore`].
+    ///
+    /// This method runs in constant time.
+    pub fn bytes_used_by_rows(&self, fixed_row_size: Size) -> usize {
+        let fixed_row_bytes = self.num_rows() * fixed_row_size.len();
+        let var_len_bytes = self.num_var_len_granules() * VarLenGranule::SIZE.len();
+        fixed_row_bytes + var_len_bytes
+    }
+
+    #[cfg(test)]
+    /// # Safety
+    ///
+    /// - `var_len_visitor` must be a valid [`VarLenMembers`] visitor
+    ///   specialized to the type and layout of rows within this [`Page`].
+    /// - `fixed_row_size` must be exactly the length in bytes of fixed rows in this page,
+    ///   which must further be the length of rows expected by the `var_len_visitor`.
+    pub unsafe fn reconstruct_bytes_used_by_rows(
+        &self,
+        fixed_row_size: Size,
+        var_len_visitor: &impl VarLenMembers,
+    ) -> usize {
+        let fixed_row_bytes = self.reconstruct_num_rows() * fixed_row_size.len();
+        let var_len_bytes = unsafe { self.reconstruct_num_var_len_granules(fixed_row_size, var_len_visitor) }
+            * VarLenGranule::SIZE.len();
+        fixed_row_bytes + var_len_bytes
+    }
+
     /// Returns the range of row data starting at `offset` and lasting `size` bytes.
     pub fn get_row_data(&self, row: PageOffset, size: Size) -> &Bytes {
         &self.row_data[row.range(size)]
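The constant-time claim for `bytes_used_by_rows` comes down to arithmetic over two counters the page already maintains (`num_rows` in the fixed header, `num_granules` in the var header). A rough standalone sketch of that arithmetic, with made-up sizes (`GRANULE_SIZE` is a stand-in for `VarLenGranule::SIZE`, whose real value comes from the page layout):

```rust
// Stand-in for `VarLenGranule::SIZE`; illustrative only.
const GRANULE_SIZE: usize = 64;

/// Mirror of the O(1) computation: fixed-row bytes plus var-len granule bytes.
fn bytes_used_by_rows(num_rows: usize, fixed_row_size: usize, num_granules: usize) -> usize {
    num_rows * fixed_row_size + num_granules * GRANULE_SIZE
}

fn main() {
    // 100 rows of a 24-byte fixed layout, plus 40 granules of var-len data:
    // 100 * 24 + 40 * 64 = 2400 + 2560 = 4960 bytes, padding and all.
    assert_eq!(bytes_used_by_rows(100, 24, 40), 4960);
}
```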
diff --git a/crates/table/src/table.rs b/crates/table/src/table.rs
index c573b6b5466..d7a7874d403 100644
--- a/crates/table/src/table.rs
+++ b/crates/table/src/table.rs
@@ -251,14 +251,21 @@ impl Table {
     /// returns true if and only if that row should be ignored.
     /// While checking unique constraints against the committed state,
     /// `MutTxId::insert` will ignore rows which are listed in the delete table.
-    pub fn check_unique_constraints<'a, I: Iterator<Item = (&'a IndexId, &'a BTreeIndex)>>(
+    ///
+    /// # Safety
+    ///
+    /// `row.row_layout() == self.row_layout()` must hold.
+    pub unsafe fn check_unique_constraints<'a, I: Iterator<Item = (&'a IndexId, &'a BTreeIndex)>>(
         &'a self,
         row: RowRef<'_>,
         adapt: impl FnOnce(btree_map::Iter<'a, IndexId, BTreeIndex>) -> I,
         mut is_deleted: impl FnMut(RowPointer) -> bool,
     ) -> Result<(), UniqueConstraintViolation> {
         for (&index_id, index) in adapt(self.indexes.iter()).filter(|(_, index)| index.is_unique()) {
-            let value = row.project(&index.indexed_columns).unwrap();
+            // SAFETY: Caller promised that `row` has the same layout as `self`.
+            // Thus, as `index.indexed_columns` is in-bounds of `self`'s layout,
+            // it's also in-bounds of `row`'s layout.
+            let value = unsafe { row.project_unchecked(&index.indexed_columns) };
             if index.seek(&value).next().is_some_and(|ptr| !is_deleted(ptr)) {
                 return Err(self.build_error_unique(index, index_id, value));
             }
@@ -569,7 +576,9 @@ impl Table {
         for (&index_id, index) in self.indexes.iter_mut() {
             // SAFETY: We just inserted `ptr`, so it must be present.
             let row_ref = unsafe { self.inner.get_row_ref_unchecked(blob_store, ptr) };
-            if index.check_and_insert(row_ref).unwrap().is_some() {
+            // SAFETY: any index in this table was constructed with the same row type as this table.
+            let violation = unsafe { index.check_and_insert(row_ref) };
+            if violation.is_err() {
                 let cols = &index.indexed_columns;
                 let value = row_ref.project(cols).unwrap();
                 let error = UniqueConstraintViolation::build(&self.schema, index, index_id, value);
@@ -1020,18 +1029,30 @@ impl Table {
     ///
     /// # Panics
     ///
-    /// Panics if `index.indexed_columns` has some column that is out of bounds of the table's row layout.
-    /// Also panics if any row would violate `index`'s unique constraint, if it has one.
-    pub fn insert_index(&mut self, blob_store: &dyn BlobStore, index_id: IndexId, mut index: BTreeIndex) {
-        index
-            .build_from_rows(self.scan_rows(blob_store))
-            .expect("`cols` should consist of valid columns for this table")
-            .inspect(|ptr| panic!("adding `index` should cause no unique constraint violations, but {ptr:?} would"));
-        self.add_index(index_id, index);
+    /// Panics if any row would violate `index`'s unique constraint, if it has one.
+    ///
+    /// # Safety
+    ///
+    /// Caller must promise that `index` was constructed with the same row type/layout as this table.
+    pub unsafe fn insert_index(&mut self, blob_store: &dyn BlobStore, index_id: IndexId, mut index: BTreeIndex) {
+        let rows = self.scan_rows(blob_store);
+        // SAFETY: Caller promised that table's row type/layout
+        // matches that which `index` was constructed with.
+        // It follows that this applies to any `rows`, as required.
+        let violation = unsafe { index.build_from_rows(rows) };
+        violation.unwrap_or_else(|ptr| {
+            panic!("adding `index` should cause no unique constraint violations, but {ptr:?} would")
+        });
+        // SAFETY: Forward caller requirement.
+        unsafe { self.add_index(index_id, index) };
     }

     /// Adds an index to the table without populating.
-    pub fn add_index(&mut self, index_id: IndexId, index: BTreeIndex) {
+    ///
+    /// # Safety
+    ///
+    /// Caller must promise that `index` was constructed with the same row type/layout as this table.
+    pub unsafe fn add_index(&mut self, index_id: IndexId, index: BTreeIndex) {
         let is_unique = index.is_unique();
         self.indexes.insert(index_id, index);

@@ -1158,6 +1179,75 @@ impl Table {
         self.compute_row_count(blob_store);
         self.rebuild_pointer_map(blob_store);
     }
+
+    /// Returns the number of rows resident in this table.
+    ///
+    /// This scales in runtime with the number of pages in the table.
+    pub fn num_rows(&self) -> u64 {
+        self.pages().iter().map(|page| page.num_rows() as u64).sum()
+    }
+
+    #[cfg(test)]
+    fn reconstruct_num_rows(&self) -> u64 {
+        self.pages().iter().map(|page| page.reconstruct_num_rows() as u64).sum()
+    }
+
+    /// Returns the number of bytes used by rows resident in this table.
+    ///
+    /// This includes data bytes, padding bytes and some overhead bytes,
+    /// as described in the docs for [`Page::bytes_used_by_rows`],
+    /// but *does not* include:
+    ///
+    /// - Unallocated space within pages.
+    /// - Per-page overhead (e.g. page headers).
+    /// - Table overhead (e.g. the [`RowTypeLayout`], [`PointerMap`], [`Schema`] &c).
+    /// - Indexes.
+    /// - Large blobs in the [`BlobStore`].
+    ///
+    /// Of these, the caller should inspect the blob store in order to account for memory usage by large blobs,
+    /// and call [`Self::bytes_used_by_index_keys`] to account for indexes,
+    /// but we intend to eat all the other overheads when billing.
+    pub fn bytes_used_by_rows(&self) -> u64 {
+        self.pages()
+            .iter()
+            .map(|page| page.bytes_used_by_rows(self.inner.row_layout.size()) as u64)
+            .sum()
+    }
+
+    #[cfg(test)]
+    fn reconstruct_bytes_used_by_rows(&self) -> u64 {
+        self.pages()
+            .iter()
+            .map(|page| unsafe {
+                // Safety: `page` is in `self`, and was constructed using `self.inner.row_layout` and `self.inner.visitor_prog`,
+                // so the three are mutually consistent.
+                page.reconstruct_bytes_used_by_rows(self.inner.row_layout.size(), &self.inner.visitor_prog)
+            } as u64)
+            .sum()
+    }
+
+    /// Returns the number of rows (or [`RowPointer`]s, more accurately)
+    /// stored in indexes by this table.
+    ///
+    /// This scales in runtime with the number of pages in the table
+    /// (it multiplies [`Self::num_rows`] by the number of indexes),
+    /// but not with the number of entries stored in the indexes.
+    pub fn num_rows_in_indexes(&self) -> u64 {
+        // Assume that each index contains all rows in the table.
+        self.num_rows() * self.indexes.len() as u64
+    }
+
+    /// Returns the number of bytes used by keys stored in indexes by this table.
+    ///
+    /// This method scales in runtime with the number of indexes in the table,
+    /// but not with the number of pages or rows.
+    ///
+    /// Key size is measured using a metric called "key size" or "data size,"
+    /// which is intended to capture the number of live user-supplied bytes,
+    /// not including representational overhead.
+    /// This is distinct from the BFLATN size measured by [`Self::bytes_used_by_rows`].
+    /// See the trait [`crate::btree_index::KeySize`] for specifics on the metric measured.
+    pub fn bytes_used_by_index_keys(&self) -> u64 {
+        self.indexes.values().map(|idx| idx.num_key_bytes()).sum()
+    }
 }

 /// A reference to a single row within a table.
@@ -1230,6 +1320,41 @@ impl<'a> RowRef<'a> {
         T::read_column(self, col.into().idx())
     }

+    /// Construct a projection of the row at `self` by extracting the `cols`.
+    ///
+    /// If `cols` contains zero or more than one column, the values of the projected columns are wrapped in a [`ProductValue`].
+    /// If `cols` is a single column, the value of that column is returned without wrapping in a `ProductValue`.
+    ///
+    /// # Safety
+    ///
+    /// - `cols` must not specify any column which is out-of-bounds for the row `self`.
+    pub unsafe fn project_unchecked(self, cols: &ColList) -> AlgebraicValue {
+        let col_layouts = &self.row_layout().product().elements;
+
+        if let Some(head) = cols.as_singleton() {
+            let head = head.idx();
+            // SAFETY: caller promised that `head` is in-bounds of `col_layouts`.
+            let col_layout = unsafe { col_layouts.get_unchecked(head) };
+            // SAFETY:
+            // - `col_layout` was just derived from the row layout.
+            // - `AlgebraicValue` is compatible with any `col_layout`.
+            // - `self` is a valid row and offsetting to `col_layout` is valid.
+            return unsafe { AlgebraicValue::unchecked_read_column(self, col_layout) };
+        }
+        let mut elements = Vec::with_capacity(cols.len() as usize);
+        for col in cols.iter() {
+            let col = col.idx();
+            // SAFETY: caller promised that any `col` is in-bounds of `col_layouts`.
+            let col_layout = unsafe { col_layouts.get_unchecked(col) };
+            // SAFETY:
+            // - `col_layout` was just derived from the row layout.
+            // - `AlgebraicValue` is compatible with any `col_layout`.
+            // - `self` is a valid row and offsetting to `col_layout` is valid.
+            elements.push(unsafe { AlgebraicValue::unchecked_read_column(self, col_layout) });
+        }
+        AlgebraicValue::product(elements)
+    }
+
     /// Construct a projection of the row at `self` by extracting the `cols`.
     ///
     /// Returns an error if `cols` specifies an index which is out-of-bounds for the row at `self`.
@@ -1556,6 +1681,7 @@ pub struct UniqueConstraintViolation {
 impl UniqueConstraintViolation {
     /// Returns a unique constraint violation error for the given `index`
     /// and the `value` that would have been duplicated.
+    #[cold]
     fn build(schema: &TableSchema, index: &BTreeIndex, index_id: IndexId, value: AlgebraicValue) -> Self {
         // Fetch the table name.
         let table_name = schema.table_name.clone();
@@ -1589,6 +1715,7 @@ impl UniqueConstraintViolation {
 impl Table {
     /// Returns a unique constraint violation error for the given `index`
     /// and the `value` that would have been duplicated.
+    #[cold]
     pub fn build_error_unique(
         &self,
         index: &BTreeIndex,
@@ -1732,7 +1859,7 @@ pub(crate) mod test {
     use spacetimedb_lib::db::raw_def::v9::{RawIndexAlgorithm, RawModuleDefV9Builder};
     use spacetimedb_primitives::{col_list, TableId};
     use spacetimedb_sats::bsatn::to_vec;
-    use spacetimedb_sats::proptest::generate_typed_row;
+    use spacetimedb_sats::proptest::{generate_typed_row, generate_typed_row_vec};
     use spacetimedb_sats::{product, AlgebraicType, ArrayValue};
     use spacetimedb_schema::def::ModuleDef;
     use spacetimedb_schema::schema::Schema as _;
@@ -1772,7 +1899,8 @@ pub(crate) mod test {
        let cols = ColList::new(0.into());

        let index = table.new_index(cols.clone(), true).unwrap();
-        table.insert_index(&NullBlobStore, index_schema.index_id, index);
+        // SAFETY: Index was derived from `table`.
+        unsafe { table.insert_index(&NullBlobStore, index_schema.index_id, index) };

        // Reserve a page so that we can check the hash.
        let pi = table.inner.pages.reserve_empty_page(table.row_size()).unwrap();
@@ -1850,6 +1978,93 @@ pub(crate) mod test {
         insert_retrieve_body(ty, AlgebraicValue::from(arr)).unwrap();
     }

+    fn reconstruct_index_num_key_bytes(table: &Table, blob_store: &dyn BlobStore, index_id: IndexId) -> u64 {
+        let index = table.get_index_by_id(index_id).unwrap();
+
+        index
+            .seek(&(..))
+            .map(|row_ptr| {
+                let row_ref = table.get_row_ref(blob_store, row_ptr).unwrap();
+                let key = row_ref.project(&index.indexed_columns).unwrap();
+                crate::btree_index::KeySize::key_size_in_bytes(&key) as u64
+            })
+            .sum()
+    }
+
+    /// Given a row type `ty`, a set of rows of that type `vals`,
+    /// and a set of columns within that type `indexed_columns`,
+    /// populate a table with `vals`, add an index on the `indexed_columns`,
+    /// and perform various assertions that the reported index size metrics are correct.
+    fn test_index_size_reporting(
+        ty: ProductType,
+        vals: Vec<ProductValue>,
+        indexed_columns: ColList,
+    ) -> Result<(), TestCaseError> {
+        let mut blob_store = HashMapBlobStore::default();
+        let mut table = table(ty.clone());
+
+        for row in &vals {
+            prop_assume!(table.insert(&mut blob_store, row).is_ok());
+        }
+
+        // We haven't added any indexes yet, so there should be 0 rows in indexes.
+        prop_assert_eq!(table.num_rows_in_indexes(), 0);
+
+        let index_id = IndexId(0);
+
+        // Add an index on the `indexed_columns`.
+        // Safety:
+        // We're using `ty` as the row type for both `table` and the new index.
+        unsafe {
+            table.insert_index(
+                &blob_store,
+                index_id,
+                BTreeIndex::new(&ty, indexed_columns.clone(), false).unwrap(),
+            );
+        }
+
+        // We have one index, which should be fully populated,
+        // so in total we should have the same number of rows in indexes as we have rows.
+        prop_assert_eq!(table.num_rows_in_indexes(), table.num_rows());
+
+        let index = table.get_index_by_id(index_id).unwrap();
+
+        // One index, so table's reporting of bytes used should match that index's reporting.
+        prop_assert_eq!(table.bytes_used_by_index_keys(), index.num_key_bytes());
+
+        // Walk all the rows in the index, sum their key size,
+        // and assert that it matches `index.num_key_bytes()`.
+        prop_assert_eq!(
+            index.num_key_bytes(),
+            reconstruct_index_num_key_bytes(&table, &blob_store, index_id)
+        );
+
+        // Walk all the rows we inserted, project them to the cols that will be their keys,
+        // sum their key size,
+        // and assert that it matches `index.num_key_bytes()`.
+        let key_size_in_pvs = vals
+            .iter()
+            .map(|row| crate::btree_index::KeySize::key_size_in_bytes(&row.project(&indexed_columns).unwrap()) as u64)
+            .sum();
+        prop_assert_eq!(index.num_key_bytes(), key_size_in_pvs);
+
+        // Add a duplicate of the same index, so we can check that all above quantities double.
+        // Safety:
+        // As above, we're using `ty` as the row type for both `table` and the new index.
+        unsafe {
+            table.insert_index(
+                &blob_store,
+                IndexId(1),
+                BTreeIndex::new(&ty, indexed_columns, false).unwrap(),
+            );
+        }
+
+        prop_assert_eq!(table.num_rows_in_indexes(), table.num_rows() * 2);
+        prop_assert_eq!(table.bytes_used_by_index_keys(), key_size_in_pvs * 2);
+
+        Ok(())
+    }
+
     proptest! {
         #![proptest_config(ProptestConfig { max_shrink_iters: 0x10000000, ..Default::default() })]

@@ -1938,6 +2153,39 @@ pub(crate) mod test {
             prop_assert_eq!(bs_pv, bs_bsatn);
             prop_assert_eq!(table_pv, table_bsatn);
         }
+
+        #[test]
+        fn row_size_reporting_matches_slow_implementations((ty, vals) in generate_typed_row_vec(128, 2048)) {
+            let mut blob_store = HashMapBlobStore::default();
+            let mut table = table(ty.clone());
+
+            for row in &vals {
+                prop_assume!(table.insert(&mut blob_store, row).is_ok());
+            }
+
+            prop_assert_eq!(table.bytes_used_by_rows(), table.reconstruct_bytes_used_by_rows());
+            prop_assert_eq!(table.num_rows(), table.reconstruct_num_rows());
+            prop_assert_eq!(table.num_rows(), vals.len() as u64);
+
+            // TODO(testing): Determine if there's a meaningful way to test that the blob store reporting is correct.
+            // I (pgoldman 2025-01-27) doubt it, as the test would be "visit every blob and sum their size,"
+            // which is already what the actual implementation does.
+        }
+
+        #[test]
+        fn index_size_reporting_matches_slow_implementations_single_column((ty, vals) in generate_typed_row_vec(128, 2048)) {
+            prop_assume!(!ty.elements.is_empty());

+            test_index_size_reporting(ty, vals, ColList::from(ColId(0)))?;
+        }
+
+        #[test]
+        fn index_size_reporting_matches_slow_implementations_two_column((ty, vals) in generate_typed_row_vec(128, 2048)) {
+            prop_assume!(ty.elements.len() >= 2);
+
+            test_index_size_reporting(ty, vals, ColList::from([ColId(0), ColId(1)]))?;
+        }
     }

     fn insert_bsatn<'a>(
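Stepping back from the tests: the new metrics are designed to compose additively into a per-table figure. A hedged usage sketch (the function, its billing framing, and the import path are illustrative assumptions, not part of the patch):

```rust
// Assumed import path; the crate lives at `crates/table` in this repo.
use spacetimedb_table::table::Table;

/// Illustrative only: combine the cheaply-maintained size metrics for one table.
fn estimated_table_bytes(table: &Table) -> u64 {
    // BFLATN bytes held by rows: fixed parts plus var-len granules, per page.
    let row_bytes = table.bytes_used_by_rows();
    // "Data size" of every key in every index, duplicates counted per row.
    let index_key_bytes = table.bytes_used_by_index_keys();
    // Large blobs are deliberately excluded here; per the docs above, they
    // should be accounted for by inspecting the `BlobStore` separately.
    row_bytes + index_key_bytes
}
```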
diff --git a/smoketests/tests/panic.py b/smoketests/tests/panic.py
index 87aa482d192..0fd917340b3 100644
--- a/smoketests/tests/panic.py
+++ b/smoketests/tests/panic.py
@@ -29,4 +29,22 @@ def test_panic(self):
         self.call("first")
         self.call("second")

-        self.assertIn("Test Passed", self.logs(2))
\ No newline at end of file
+        self.assertIn("Test Passed", self.logs(2))
+
+class ReducerError(Smoketest):
+    MODULE_CODE = """
+use spacetimedb::ReducerContext;
+
+#[spacetimedb::reducer]
+fn fail(_ctx: &ReducerContext) -> Result<(), String> {
+    Err("oopsie :(".into())
+}
+"""
+
+    def test_reducer_error_message(self):
+        """Tests to ensure an error message returned from a reducer gets printed to logs"""
+
+        with self.assertRaises(Exception):
+            self.call("fail")
+
+        self.assertIn("oopsie :(", self.logs(2))