Skip to content

Commit

Permalink
feat: use dlc-macros for async dlc-manager
Browse files Browse the repository at this point in the history
  • Loading branch information
bennyhodl committed Oct 8, 2024
1 parent 0d16b4b commit b970675
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 39 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"dlc-messages",
"dlc-trie",
"dlc-manager",
"dlc-macros",
"mocks",
"sample",
"simple-wallet",
Expand Down
21 changes: 21 additions & 0 deletions dlc-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[package]
name = "dlc-macros"
authors = ["benny b <[email protected]>"]
version = "0.1.0"
edition = "2018"
description = "Procedural macros for writing optionally asynchronous code for traits and functions."
homepage = "https://github.com/p2pderivatives/rust-dlc"
license-file = "../LICENSE"
repository = "https://github.com/p2pderivatives/rust-dlc/tree/master/dlc-macros"

[dependencies]
proc-macro2 = "1.0.87"
quote = "1.0.37"
syn = { version = "2.0.79", features = ["full", "extra-traits"] }
tokio = { version = "1.40.0", features = ["macros", "test-util"] }

[lib]
proc-macro = true

[dev-dependencies]
trybuild = "1.0.99"
190 changes: 190 additions & 0 deletions dlc-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
//! Procedural macros for writing optionally asynchronous code for traits and functions.
//! Inspires by [`bdk-macros`](https://github.com/bitcoindevkit/bdk/blob/v0.29.0/macros/src/lib.rs)
#![crate_name = "dlc_macros"]
// Coding conventions
#![forbid(unsafe_code)]
#![deny(non_upper_case_globals)]
#![deny(non_camel_case_types)]
#![deny(non_snake_case)]
#![deny(unused_mut)]
#![deny(dead_code)]
#![deny(unused_imports)]
#![deny(missing_docs)]

use proc_macro::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use syn::{
parse_macro_input, Attribute, Expr, ImplItem, Item, ItemFn, ItemImpl, ItemTrait, TraitItem,
};

// Check if the function attributes contains #[maybe_async].
// For conditional compilation of member functions.
fn is_maybe_async_attr(attr: &Attribute) -> bool {
// Check if the attribute path is exactly "maybe_async"
if attr.path().is_ident("maybe_async") {
return true;
}

// Check if the attribute path is of the form "module::maybe_async"
if let Some(last_segment) = attr.path().segments.last() {
return last_segment.ident == "maybe_async";
}
false
}

// Add async to a standalone function.
fn add_async_to_fn(mut func: ItemFn) -> TokenStream {
// For standalone functions, we'll always make them potentially async
let sync_version = func.clone();
func.sig.asyncness = Some(syn::Token![async](func.sig.span()));

quote! {
#[cfg(not(feature = "async"))]
#sync_version

#[cfg(feature = "async")]
#func
}
.into()
}

// Adds the `async_trait` macro to the trait and appends async to all of
// the member functions marked `#[maybe_async]`.
fn add_async_to_trait(mut trait_item: ItemTrait) -> TokenStream {
// Check if the trait itself has the `#[maybe_async]` attribute
let is_trait_async = trait_item.attrs.iter().any(is_maybe_async_attr);
trait_item.attrs.retain(|attr| !is_maybe_async_attr(attr)); // Remove the attribute from the trait

let mut async_trait_item = trait_item.clone();

for item in &mut async_trait_item.items {
if let TraitItem::Fn(method) = item {
if is_trait_async || method.attrs.iter().any(is_maybe_async_attr) {
method.sig.asyncness = Some(syn::Token![async](method.sig.span()));
method.attrs.retain(is_maybe_async_attr);
}
}
}

quote! {
#[cfg(not(feature = "async"))]
#trait_item

#[cfg(feature = "async")]
#[async_trait::async_trait]
#async_trait_item
}
.into()
}

// Adds async to a member of a struct implementation method.
fn add_async_to_impl(impl_item: ItemImpl) -> TokenStream {
let mut async_impl_item = impl_item.clone();

for item in &mut async_impl_item.items {
if let ImplItem::Fn(method) = item {
if method.attrs.iter().any(is_maybe_async_attr) {
method.sig.asyncness = Some(syn::Token![async](method.sig.span()));
method.attrs.retain(|attr| !is_maybe_async_attr(attr));
}
}
}

quote! {
#[cfg(not(feature = "async"))]
#impl_item

#[cfg(feature = "async")]
#[async_trait::async_trait]
#async_impl_item
}
.into()
}

/// Makes a method or every method of a trait `async`, if the `async` feature is enabled.
///
/// Requires the `async-trait` crate as a dependency whenever this attribute is used on a trait
/// definition or trait implementation.
#[proc_macro_attribute]
pub fn maybe_async(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as Item);

match input {
Item::Fn(func) => add_async_to_fn(func),
Item::Trait(trait_item) => add_async_to_trait(trait_item),
Item::Impl(impl_item) => add_async_to_impl(impl_item),
Item::Verbatim(verbatim) => {
// This case handles unexpected verbatim content, like doc comments
quote! {
#verbatim
}
.into()
}
other => {
let item_type = format!("{:?}", other);
let error_msg = format!(
"#[maybe_async] can only be used on functions, traits, or impl blocks, not on: {}",
item_type
);
quote! {
compile_error!(#error_msg);
}
.into()
}
}
}

/// Awaits, if the `async` feature is enabled.
#[proc_macro]
pub fn maybe_await(input: TokenStream) -> TokenStream {
let expr = parse_macro_input!(input as Expr);
let quoted = quote! {
{
#[cfg(not(feature = "async"))]
{
#expr
}

#[cfg(feature = "async")]
{
#expr.await
}
}
};

quoted.into()
}

/// Awaits, if the `async` feature is enabled, uses `tokio::Runtime::block_on()` otherwise
///
/// Requires the `tokio` crate as a dependecy with `rt-core` or `rt-threaded` to build.
#[proc_macro]
pub fn await_or_block(expr: TokenStream) -> TokenStream {
let expr = parse_macro_input!(expr as Expr);
let quoted = quote! {
{
#[cfg(not(feature = "async"))]
{
tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(#expr)
}

#[cfg(feature = "async")]
{
#expr.await
}
}
};

quoted.into()
}

#[cfg(test)]
mod tests {
#[test]
fn test_async_trait() {
let t = trybuild::TestCases::new();
t.pass("tests/sync.rs");
}
}
27 changes: 27 additions & 0 deletions dlc-macros/tests/maybe_async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use dlc_macros::*;

#[maybe_async]
trait TestTrait {
fn test_method(&self) -> Result<(), std::io::Error>;
}

struct TestStruct;

#[maybe_async]
impl TestTrait for TestStruct {
fn test_method(&self) -> Result<(), std::io::Error> {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_sync_implementation() {
let test_struct = TestStruct;
let test = maybe_await!(test_struct.test_method());
assert!(test.is_ok());
}
}
20 changes: 20 additions & 0 deletions dlc-macros/tests/sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use dlc_macros::maybe_async;

/// Documentation
#[maybe_async]
pub trait Example {
/// Documentation
#[maybe_async]
fn example_fn(&self);
}

struct Test;

impl Example for Test {
fn example_fn(&self) {}
}

fn main() {
let test = Test;
test.example_fn();
}
5 changes: 4 additions & 1 deletion dlc-manager/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@ std = ["dlc/std", "dlc-messages/std", "dlc-trie/std", "bitcoin/std", "lightning/
fuzztarget = ["rand_chacha"]
parallel = ["dlc-trie/parallel"]
use-serde = ["serde", "dlc/use-serde", "dlc-messages/use-serde", "dlc-trie/use-serde"]
async = ["dep:async-trait"]

[dependencies]
async-trait = "0.1.50"
async-trait = { version = "0.1.50", optional = true }
bitcoin = { version = "0.32.2", default-features = false }
dlc = { version = "0.6.0", default-features = false, path = "../dlc" }
dlc-messages = { version = "0.6.0", default-features = false, path = "../dlc-messages" }
dlc-trie = { version = "0.6.0", default-features = false, path = "../dlc-trie" }
dlc-macros = { version = "0.1.0", path = "../dlc-macros" }
hex = { package = "hex-conservative", version = "0.1" }
lightning = { version = "0.0.124", default-features = false, features = ["grind_signatures"] }
log = "0.4.14"
rand_chacha = {version = "0.3.1", optional = true}
secp256k1-zkp = {version = "0.11.0"}
serde = {version = "1.0", optional = true}
# bdk-macros = "0.6.0"

[dev-dependencies]
bitcoin-rpc-provider = {path = "../bitcoin-rpc-provider"}
Expand Down
6 changes: 6 additions & 0 deletions dlc-manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#![deny(unused_imports)]
#![deny(missing_docs)]

#[cfg(feature = "async")]
extern crate async_trait;
extern crate bitcoin;
extern crate dlc;
Expand Down Expand Up @@ -44,6 +45,7 @@ use channel::signed_channel::{SignedChannel, SignedChannelStateType};
use channel::Channel;
use contract::PreClosedContract;
use contract::{offered_contract::OfferedContract, signed_contract::SignedContract, Contract};
use dlc_macros::maybe_async;
use dlc_messages::oracle_msgs::{OracleAnnouncement, OracleAttestation};
use dlc_messages::ser_impls::{read_address, write_address};
use error::Error;
Expand Down Expand Up @@ -225,13 +227,17 @@ pub trait Storage {
fn get_chain_monitor(&self) -> Result<Option<ChainMonitor>, Error>;
}

#[allow(missing_docs)]
/// Oracle trait provides access to oracle information.
#[maybe_async]
pub trait Oracle {
/// Returns the public key of the oracle.
fn get_public_key(&self) -> XOnlyPublicKey;
#[maybe_async]
/// Returns the announcement for the event with the given id if found.
fn get_announcement(&self, event_id: &str) -> Result<OracleAnnouncement, Error>;
/// Returns the attestation for the event with the given id if found.
#[maybe_async]
fn get_attestation(&self, event_id: &str) -> Result<OracleAttestation, Error>;
}

Expand Down
Loading

0 comments on commit b970675

Please sign in to comment.