Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add concurrency option to control max concurrent writes to db #460

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/src/db_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ use uriparse::URI;
pub struct DataSourceParams<'a> {
pub uri: URI<'a>,
pub schema: Option<String>, // PostgreSQL
pub concurrency: usize,
}
1 change: 1 addition & 0 deletions core/src/schema/content/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ impl Compile for DatasourceContent {
let params = DataSourceParams {
uri: URI::try_from(self.path.as_str())?,
schema: None,
concurrency: 1,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really sure what this is used for and what the default concurrency should be here

};
let iter = get_iter(params).map(|i| -> Box<dyn Iterator<Item = Value>> {
if !self.cycle {
Expand Down
1 change: 1 addition & 0 deletions synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ strsim = "0.10.0"

async-std = { version = "1.12", features = [ "attributes", "unstable" ] }
async-trait = "0.1.50"
async-lock = "2.6"
futures = "0.3.15"

fs2 = "0.4.3"
Expand Down
1 change: 1 addition & 0 deletions synth/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn bench_generate_n_to_stdout(size: usize) {
seed: Some(0),
random: false,
schema: None,
concurrency: 3,
});
let output = io::stdout();
Cli::new().unwrap().run(args, output).await.unwrap()
Expand Down
3 changes: 3 additions & 0 deletions synth/src/cli/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ where
"postgres" | "postgresql" => Box::new(PostgresExportStrategy {
uri_string: params.uri.to_string(),
schema: params.schema,
concurrency: params.concurrency,
}),
"mongodb" => Box::new(MongoExportStrategy {
uri_string: params.uri.to_string(),
concurrency: params.concurrency,
}),
"mysql" | "mariadb" => Box::new(MySqlExportStrategy {
uri_string: params.uri.to_string(),
concurrency: params.concurrency,
}),
"json" => {
if params.uri.path() == "" {
Expand Down
9 changes: 9 additions & 0 deletions synth/src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ impl<'w> Cli {
uri: URI::try_from(cmd.from.as_str())
.with_context(|| format!("Parsing import URI '{}'", cmd.from))?,
schema: cmd.schema,
concurrency: 1,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've hardcoded import jobs to concurrency 1 since I think we only use a single db connection here, but let me know if it should also be exposed in the import command.

}
.try_into()?;

Expand Down Expand Up @@ -176,6 +177,7 @@ impl<'w> Cli {
uri: URI::try_from(cmd.to.as_str())
.with_context(|| format!("Parsing generation URI '{}'", cmd.to))?,
schema: cmd.schema,
concurrency: cmd.concurrency,
}
.try_into()?;

Expand Down Expand Up @@ -289,6 +291,13 @@ pub struct GenerateCommand {
)]
#[serde(skip)]
pub schema: Option<String>,
#[structopt(
long,
help = "The maximum number of concurrent tasks writing to the database.",
default_value = "3"
)]
#[serde(skip)]
pub concurrency: usize,
}

#[derive(StructOpt, Serialize)]
Expand Down
8 changes: 7 additions & 1 deletion synth/src/cli/mongo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use synth_core::{Content, Namespace, Value};
#[derive(Clone, Debug)]
pub struct MongoExportStrategy {
pub uri_string: String,
pub concurrency: usize,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -152,7 +153,12 @@ fn bson_to_content(bson: &Bson) -> Content {

impl ExportStrategy for MongoExportStrategy {
fn export(&self, _namespace: Namespace, sample: SamplerOutput) -> Result<()> {
let mut client = Client::with_uri_str(&self.uri_string)?;
let mut client_options = ClientOptions::parse(&self.uri_string)?;
client_options.max_pool_size = Some(self.concurrency.try_into().unwrap());

info!("Connecting to database at {} ...", &self.uri_string);

let mut client = Client::with_options(client_options)?;

match sample {
SamplerOutput::Collection(name, value) => {
Expand Down
16 changes: 13 additions & 3 deletions synth/src/cli/mysql.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::cli::export::{create_and_insert_values, ExportStrategy};
use crate::cli::import::ImportStrategy;
use crate::cli::import_utils::build_namespace_import;
use crate::datasource::mysql_datasource::MySqlDataSource;
use crate::datasource::mysql_datasource::{MySqlConnectParams, MySqlDataSource};
use crate::datasource::DataSource;
use crate::sampler::SamplerOutput;
use anyhow::Result;
Expand All @@ -10,11 +10,17 @@ use synth_core::schema::Namespace;
#[derive(Clone, Debug)]
pub struct MySqlExportStrategy {
pub uri_string: String,
pub concurrency: usize,
}

impl ExportStrategy for MySqlExportStrategy {
fn export(&self, _namespace: Namespace, sample: SamplerOutput) -> Result<()> {
let datasource = MySqlDataSource::new(&self.uri_string)?;
let connect_params = MySqlConnectParams {
uri: self.uri_string.clone(),
concurrency: self.concurrency,
};

let datasource = MySqlDataSource::new(&connect_params)?;

create_and_insert_values(sample, &datasource)
}
Expand All @@ -27,7 +33,11 @@ pub struct MySqlImportStrategy {

impl ImportStrategy for MySqlImportStrategy {
fn import(&self) -> Result<Namespace> {
let datasource = MySqlDataSource::new(&self.uri_string)?;
let connect_params = MySqlConnectParams {
uri: self.uri_string.clone(),
concurrency: 1,
};
let datasource = MySqlDataSource::new(&connect_params)?;

build_namespace_import(&datasource)
}
Expand Down
3 changes: 3 additions & 0 deletions synth/src/cli/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ use synth_core::schema::Namespace;
pub struct PostgresExportStrategy {
pub uri_string: String,
pub schema: Option<String>,
pub concurrency: usize,
}

impl ExportStrategy for PostgresExportStrategy {
fn export(&self, _namespace: Namespace, sample: SamplerOutput) -> Result<()> {
let connect_params = PostgresConnectParams {
uri: self.uri_string.clone(),
schema: self.schema.clone(),
concurrency: self.concurrency,
};

let datasource = PostgresDataSource::new(&connect_params)?;
Expand All @@ -37,6 +39,7 @@ impl ImportStrategy for PostgresImportStrategy {
let connect_params = PostgresConnectParams {
uri: self.uri_string.clone(),
schema: self.schema.clone(),
concurrency: 1,
};

let datasource = PostgresDataSource::new(&connect_params)?;
Expand Down
28 changes: 24 additions & 4 deletions synth/src/datasource/mysql_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,31 @@ use synth_gen::prelude::*;
/// - MySql aliases bool and boolean data types as tinyint. We currently define all tinyint as i8.
/// Ideally, the user can define a way to force certain fields as bool rather than i8.

pub struct MySqlConnectParams {
pub(crate) uri: String,
pub(crate) concurrency: usize,
}

pub struct MySqlDataSource {
pool: Pool<MySql>,
concurrency: usize,
}

#[async_trait]
impl DataSource for MySqlDataSource {
type ConnectParams = String;
type ConnectParams = MySqlConnectParams;

fn new(connect_params: &Self::ConnectParams) -> Result<Self> {
task::block_on(async {
let pool = MySqlPoolOptions::new()
.max_connections(3) //TODO expose this as a user config?
.connect(connect_params.as_str())
.max_connections(connect_params.concurrency.try_into().unwrap())
.connect(connect_params.uri.as_str())
.await?;

Ok::<Self, anyhow::Error>(MySqlDataSource { pool })
Ok::<Self, anyhow::Error>(MySqlDataSource {
pool,
concurrency: connect_params.concurrency,
})
})
}

Expand All @@ -53,6 +62,17 @@ impl SqlxDataSource for MySqlDataSource {

const IDENTIFIER_QUOTE: char = '`';

fn clone(&self) -> Self {
Self {
pool: Pool::clone(&self.pool),
concurrency: self.concurrency,
}
}

fn get_concurrency(&self) -> usize {
self.concurrency
}

fn get_pool(&self) -> Pool<Self::DB> {
Pool::clone(&self.pool)
}
Expand Down
18 changes: 17 additions & 1 deletion synth/src/datasource/postgres_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ use synth_core::{Content, Value};
pub struct PostgresConnectParams {
pub(crate) uri: String,
pub(crate) schema: Option<String>,
pub(crate) concurrency: usize,
}

pub struct PostgresDataSource {
pool: Pool<Postgres>,
single_thread_pool: Pool<Postgres>,
schema: String, // consider adding a type schema
concurrency: usize,
}

#[async_trait]
Expand All @@ -42,7 +44,7 @@ impl DataSource for PostgresDataSource {

let mut arc = Arc::new(schema.clone());
let pool = PgPoolOptions::new()
.max_connections(3) //TODO expose this as a user config?
.max_connections(connect_params.concurrency.try_into().unwrap())
.after_connect(move |conn, _meta| {
let schema = arc.clone();
Box::pin(async move {
Expand Down Expand Up @@ -76,6 +78,7 @@ impl DataSource for PostgresDataSource {
pool,
single_thread_pool,
schema,
concurrency: connect_params.concurrency,
})
})
}
Expand Down Expand Up @@ -119,6 +122,19 @@ impl SqlxDataSource for PostgresDataSource {

const IDENTIFIER_QUOTE: char = '\"';

fn clone(&self) -> Self {
Self {
pool: Pool::clone(&self.pool),
single_thread_pool: Pool::clone(&self.single_thread_pool),
schema: self.schema.clone(),
concurrency: self.concurrency,
}
}

fn get_concurrency(&self) -> usize {
self.concurrency
}

fn get_pool(&self) -> Pool<Self::DB> {
Pool::clone(&self.single_thread_pool)
}
Expand Down
37 changes: 31 additions & 6 deletions synth/src/datasource/relational_datasource.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use crate::datasource::DataSource;
use anyhow::Result;
use async_lock::Semaphore;
use async_std::task;
use async_trait::async_trait;
use beau_collector::BeauCollector;
use futures::future::join_all;
use sqlx::{
query::Query, Arguments, Connection, Database, Encode, Executor, IntoArguments, Pool, Type,
};
use std::sync::Arc;
use synth_core::{Content, Value};
use synth_gen::value::Number;

Expand Down Expand Up @@ -52,12 +55,17 @@ pub trait SqlxDataSource: DataSource {

const IDENTIFIER_QUOTE: char;

fn clone(&self) -> Self;

/// Gets a pool to execute queries with
fn get_pool(&self) -> Pool<Self::DB>;

/// Gets a multithread pool to execute queries with
fn get_multithread_pool(&self) -> Pool<Self::DB>;

/// Get the maximum concurrency for the data source
fn get_concurrency(&self) -> usize;

/// Prepare a single query with data source specifics
fn query<'q>(&self, query: &'q str) -> Query<'q, Self::DB, Self::Arguments> {
sqlx::query(query)
Expand Down Expand Up @@ -117,8 +125,8 @@ pub trait SqlxDataSource: DataSource {
) -> Result<<Self::DB as Database>::QueryResult>
where
for<'c> &'c mut Self::Connection: Executor<'c, Database = Self::DB>,
Value: Type<Self::DB>,
for<'d> Value: Encode<'d, Self::DB>,
Value: Type<Self::DB> + Send + Sync,
for<'d> Value: Encode<'d, Self::DB> + Send + Sync,
{
let mut query = sqlx::query::<Self::DB>(query.as_str());

Expand All @@ -132,7 +140,7 @@ pub trait SqlxDataSource: DataSource {
}
}

pub async fn insert_relational_data<T: SqlxDataSource + Sync>(
pub async fn insert_relational_data<T: SqlxDataSource + Sync + Send + 'static>(
datasource: &T,
collection_name: &str,
collection: &[Value],
Expand All @@ -146,6 +154,7 @@ where
for<'d> Value: Encode<'d, T::DB>,
{
let batch_size = DEFAULT_INSERT_BATCH_SIZE;
let max_concurrency = datasource.get_concurrency();

if collection.is_empty() {
println!("Collection {collection_name} generated 0 values. Skipping insertion...",);
Expand Down Expand Up @@ -208,9 +217,19 @@ where
.collect::<Vec<String>>()
.join(",");

let mut futures = Vec::with_capacity(collection.len());
let collection_chunks = collection.chunks(batch_size);
let mut futures = Vec::with_capacity(collection_chunks.len());

info!(
"Inserting {} rows for {}...",
collection.len(),
collection_name
);

let semaphore = Arc::new(Semaphore::new(max_concurrency));
for rows in collection_chunks {
let permit = semaphore.clone().acquire_arc().await;

for rows in collection.chunks(batch_size) {
let table_name = datasource.get_table_name_for_insert(collection_name);
let mut query = format!("INSERT INTO {table_name} ({column_names}) VALUES \n");

Expand All @@ -233,7 +252,13 @@ where
query.push_str(",\n");
}
}
let future = datasource.execute_query(query, query_params);
let datasource = datasource.clone();
let future = task::spawn(async move {
let result = datasource.execute_query(query, query_params).await;
drop(permit);
result
});

futures.push(future);
}

Expand Down
1 change: 1 addition & 0 deletions synth/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub async fn generate_scenario(namespace: &str, scenario: Option<String>) -> Res
seed: Some(5),
size: 10,
to: "json:".to_string(),
concurrency: 3,
}))
.await
}
Expand Down