refactor: simplify journal and restore streamer cancelation
BorysTheDev committed Feb 14, 2025
Parent: 038b428 · Commit: c6c3c74
Showing 11 changed files with 79 additions and 72 deletions.
4 changes: 2 additions & 2 deletions src/server/cluster/incoming_slot_migration.cc
@@ -59,7 +59,7 @@ class ClusterShardMigration {
JournalReader reader{source, 0};
TransactionReader tx_reader;

- while (!cntx->IsCancelled()) {
+ while (cntx->IsRunning()) {
if (pause_) {
ThisFiber::SleepFor(100ms);
continue;
@@ -126,7 +126,7 @@ class ClusterShardMigration {

private:
void ExecuteTx(TransactionData&& tx_data, ExecutionState* cntx) {
- if (cntx->IsCancelled()) {
+ if (!cntx->IsRunning()) {
return;
}
if (!tx_data.IsGlobalCmd()) {
21 changes: 11 additions & 10 deletions src/server/cluster/outgoing_slot_migration.cc
@@ -84,6 +84,8 @@ class OutgoingMigration::SliceSlotMigration : private ProtocolClient {
}

void Cancel() {
+ // We don't care about errors during cancel
+ cntx_.SwitchErrorHandler([](auto ge) {});
// Close socket for clean disconnect.
CloseSocket();
streamer_.Cancel();
@@ -194,13 +196,12 @@ void OutgoingMigration::SyncFb() {
break;
}

- last_error_ = cntx_.GetError();
- cntx_.Reset(nullptr);

- if (last_error_) {
- LOG(ERROR) << last_error_.Format();
+ if (cntx_.IsError()) {
+ last_error_ = cntx_.GetError();
+ LOG(ERROR) << last_error_;
ThisFiber::SleepFor(1000ms); // wait some time before next retry
}
+ cntx_.Reset(nullptr);

VLOG(1) << "Connecting to target node";
auto timeout = absl::GetFlag(FLAGS_slot_migration_connection_timeout_ms) * 1ms;
@@ -246,7 +247,7 @@ void OutgoingMigration::SyncFb() {
}

OnAllShards([this](auto& migration) { migration->PrepareFlow(cf_->MyID()); });
- if (cntx_.GetError()) {
+ if (cntx_.IsError()) {
continue;
}

@@ -257,13 +258,13 @@
OnAllShards([](auto& migration) { migration->PrepareSync(); });
}

- if (cntx_.GetError()) {
+ if (cntx_.IsError()) {
continue;
}

OnAllShards([](auto& migration) { migration->RunSync(); });

- if (cntx_.GetError()) {
+ if (cntx_.IsError()) {
continue;
}

@@ -273,7 +274,7 @@
VLOG(1) << "Waiting for migration to finalize...";
ThisFiber::SleepFor(500ms);
}
- if (cntx_.GetError()) {
+ if (cntx_.IsError()) {
continue;
}
break;
@@ -288,7 +289,7 @@ bool OutgoingMigration::FinalizeMigration(long attempt) {
LOG(INFO) << "Finalize migration for " << cf_->MyID() << " : " << migration_info_.node_info.id
<< " attempt " << attempt;
if (attempt > 1) {
- if (cntx_.GetError()) {
+ if (cntx_.IsError()) {
return true;
}
auto timeout = absl::GetFlag(FLAGS_slot_migration_connection_timeout_ms) * 1ms;
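The reordering in the SyncFb() hunk above is easier to see in isolation: the error is now inspected while the state is still ERROR, and Reset() runs only afterwards, so the backoff is driven directly by the execution state (previously the error was copied out before the reset). A compilable sketch of just that ordering, using stand-in types rather than the actual Dragonfly members:

```cpp
#include <atomic>
#include <chrono>
#include <thread>

// Stand-in for the slice of ExecutionState used here (see common.h below).
struct StateStub {
  enum class State { RUN, CANCELLED, ERROR };
  std::atomic<State> state{State::RUN};
  bool IsError() const { return state.load() == State::ERROR; }
  void Reset() { state.store(State::RUN); }
};

// One pass of the retry path: inspect, back off, then reset to RUN so the
// next connection attempt starts from a clean state.
void RetryStep(StateStub& cntx) {
  if (cntx.IsError()) {
    // the real code captures and logs the error here
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));  // backoff
  }
  cntx.Reset();
  // ... reconnect and rerun the migration steps; in the real loop each step
  // is followed by `if (cntx_.IsError()) continue;` to restart the attempt.
}
```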
10 changes: 5 additions & 5 deletions src/server/common.cc
@@ -349,9 +349,9 @@ GenericError ExecutionState::GetError() const {
return err_;
}

- const Cancellation* ExecutionState::GetCancellation() const {
- return this;
- }
+ // const Cancellation* ExecutionState::GetCancellation() const {
+ // return this;
+ // }

void ExecutionState::ReportCancelError() {
ReportError(std::make_error_code(errc::operation_canceled), "Context cancelled");
@@ -363,7 +363,7 @@ void ExecutionState::Reset(ErrHandler handler) {
unique_lock lk{err_mu_};
err_ = {};
err_handler_ = std::move(handler);
- Cancellation::flag_.store(false, std::memory_order_relaxed);
+ state_.store(State::RUN, std::memory_order_relaxed);
fb.swap(err_handler_fb_);
lk.unlock();
fb.JoinIfNeeded();
@@ -402,7 +402,7 @@ GenericError ExecutionState::ReportErrorInternal(GenericError&& err) {
// We can move err_handler_ because it should run at most once.
if (err_handler_)
err_handler_fb_ = fb2::Fiber("report_internal_error", std::move(err_handler_), err_);
- Cancellation::Cancel();
+ state_.store(State::ERROR, std::memory_order_relaxed);
return err_;
}

26 changes: 20 additions & 6 deletions src/server/common.h
@@ -253,21 +253,34 @@ using AggregateGenericError = AggregateValue<GenericError>;
// handler is run in a separate handler to free up the caller.
//
// ReportCancelError() reporting an `errc::operation_canceled` error.
- class ExecutionState : protected Cancellation {
+ class ExecutionState {
public:
using ErrHandler = std::function<void(const GenericError&)>;

ExecutionState() = default;
- ExecutionState(ErrHandler err_handler)
- : Cancellation{}, err_{}, err_handler_{std::move(err_handler)} {
+ ExecutionState(ErrHandler err_handler) : err_handler_{std::move(err_handler)} {
}

~ExecutionState();

// Cancels the context by submitting an `errc::operation_canceled` error.
void ReportCancelError();
- using Cancellation::IsCancelled;
- const Cancellation* GetCancellation() const;

+ bool IsRunning() const {
+ return state_.load(std::memory_order_relaxed) == State::RUN;
+ }
+
+ bool IsError() const {
+ return state_.load(std::memory_order_relaxed) == State::ERROR;
+ }
+
+ bool IsCancelled() const {
+ return state_.load(std::memory_order_relaxed) == State::CANCELLED;
+ }
+
+ void Cancel() {
+ state_.store(State::CANCELLED, std::memory_order_relaxed);
+ }

GenericError GetError() const;

@@ -293,9 +306,10 @@ class ExecutionState : protected Cancellation {
void JoinErrorHandler();

private:
- // Report error.
GenericError ReportErrorInternal(GenericError&& err);

+ enum class State { RUN, CANCELLED, ERROR };
+ std::atomic<State> state_{State::RUN};
GenericError err_;
ErrHandler err_handler_;
util::fb2::Fiber err_handler_fb_;
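Taken together, the common.h and common.cc hunks replace the `protected Cancellation` base (a single atomic flag) with an explicit three-state machine. A minimal, compilable sketch of the resulting shape, assuming the error bookkeeping, handler fiber, and locking are stripped away:

```cpp
#include <atomic>

// Simplified view of the ExecutionState introduced by this commit (sketch only).
class ExecutionState {
 public:
  bool IsRunning() const { return state_.load(std::memory_order_relaxed) == State::RUN; }
  bool IsError() const { return state_.load(std::memory_order_relaxed) == State::ERROR; }
  bool IsCancelled() const { return state_.load(std::memory_order_relaxed) == State::CANCELLED; }

  void Cancel() { state_.store(State::CANCELLED, std::memory_order_relaxed); }

  // In the real class these also store the error and spawn the handler fiber;
  // only the state transitions are shown here.
  void ReportError() { state_.store(State::ERROR, std::memory_order_relaxed); }
  void Reset() { state_.store(State::RUN, std::memory_order_relaxed); }

 private:
  enum class State { RUN, CANCELLED, ERROR };
  std::atomic<State> state_{State::RUN};
};
```

The practical consequence, visible in the remaining files, is that call sites which previously asked `!IsCancelled()` now ask `IsRunning()`, so a reported error stops the same loops that an explicit cancel does.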
2 changes: 1 addition & 1 deletion src/server/dflycmd.cc
@@ -88,7 +88,7 @@ bool WaitReplicaFlowToCatchup(absl::Time end_time, const DflyCmd::ReplicaInfo* r
<< ", expecting " << shard->journal()->GetLsn();
return false;
}
- if (replica->cntx.IsCancelled()) {
+ if (!replica->cntx.IsRunning()) {
return false;
}
VLOG(1) << "Replica lsn:" << flow->last_acked_lsn
33 changes: 17 additions & 16 deletions src/server/journal/streamer.cc
@@ -42,7 +42,7 @@ JournalStreamer::JournalStreamer(journal::Journal* journal, ExecutionState* cntx
}

JournalStreamer::~JournalStreamer() {
- if (!cntx_->IsCancelled()) {
+ if (!cntx_->IsError()) {
DCHECK_EQ(in_flight_bytes_, 0u);
}
VLOG(1) << "~JournalStreamer";
@@ -83,7 +83,7 @@ void JournalStreamer::Cancel() {
VLOG(1) << "JournalStreamer::Cancel";
waker_.notifyAll();
journal_->UnregisterOnChange(journal_cb_id_);
- if (!cntx_->IsCancelled()) {
+ if (!cntx_->IsError()) {
WaitForInflightToComplete();
}
}
@@ -134,10 +134,12 @@ void JournalStreamer::OnCompletion(std::error_code ec, size_t len) {
DVLOG(3) << "Completing " << in_flight_bytes_;
in_flight_bytes_ = 0;
pending_buf_.Pop();
- if (ec && !IsStopped()) {
- cntx_->ReportError(ec);
- } else if (!pending_buf_.Empty() && !IsStopped()) {
- AsyncWrite();
+ if (cntx_->IsRunning()) {
+ if (ec) {
+ cntx_->ReportError(ec);
+ } else if (!pending_buf_.Empty()) {
+ AsyncWrite();
+ }
}

// notify ThrottleIfNeeded or WaitForInflightToComplete that waits
@@ -149,7 +151,7 @@
}

void JournalStreamer::ThrottleIfNeeded() {
- if (IsStopped() || !IsStalled())
+ if (!cntx_->IsRunning() || !IsStalled())
return;

auto next =
@@ -158,7 +160,7 @@
size_t sent_start = total_sent_;

std::cv_status status =
- waker_.await_until([this]() { return !IsStalled() || IsStopped(); }, next);
+ waker_.await_until([this]() { return !IsStalled() || !cntx_->IsRunning(); }, next);
if (status == std::cv_status::timeout) {
LOG(WARNING) << "Stream timed out, inflight bytes/sent start: " << inflight_start << "/"
<< sent_start << ", end: " << in_flight_bytes_ << "/" << total_sent_;
@@ -188,7 +190,7 @@ RestoreStreamer::RestoreStreamer(DbSlice* slice, cluster::SlotSet slots, journal
}

void RestoreStreamer::Start(util::FiberSocketBase* dest, bool send_lsn) {
- if (fiber_cancelled_)
+ if (!cntx_->IsRunning())
return;

VLOG(1) << "RestoreStreamer start";
@@ -206,16 +208,16 @@ void RestoreStreamer::Run() {
PrimeTable* pt = &db_array_[0]->prime;

do {
- if (fiber_cancelled_)
+ if (!cntx_->IsRunning())
return;
cursor = pt->TraverseBuckets(cursor, [&](PrimeTable::bucket_iterator it) {
- if (fiber_cancelled_) // Could be cancelled any time as Traverse may preempt
+ if (!cntx_->IsRunning()) // Could be cancelled any time as Traverse may preempt
return;

db_slice_->FlushChangeToEarlierCallbacks(0 /*db_id always 0 for cluster*/,
DbSlice::Iterator::FromPrime(it), snapshot_version_);

- if (fiber_cancelled_) // Could have been cancelled in above call too
+ if (!cntx_->IsRunning()) // Could have been cancelled in above call too
return;

std::lock_guard guard(big_value_mu_);
@@ -231,7 +233,7 @@
ThisFiber::Yield();
last_yield = 0;
}
- } while (cursor && !fiber_cancelled_);
+ } while (cursor);

VLOG(1) << "RestoreStreamer finished loop of " << my_slots_.ToSlotRanges().ToString()
<< ", shard " << db_slice_->shard_id() << ". Buckets looped " << stats_.buckets_loop;
@@ -252,8 +254,7 @@ void RestoreStreamer::SendFinalize(long attempt) {
writer.Write(entry);
Write(std::move(sink).str());

- // TODO: is the intent here to flush everything?
- //
+ // DFLYMIGRATE ACK command has a timeout so we want to send it only when LSN is ready to be sent
ThrottleIfNeeded();
}

@@ -263,7 +264,7 @@ RestoreStreamer::~RestoreStreamer() {
void RestoreStreamer::Cancel() {
auto sver = snapshot_version_;
snapshot_version_ = 0; // to prevent double cancel in another fiber
- fiber_cancelled_ = true;
+ cntx_->Cancel();
if (sver != 0) {
db_slice_->UnregisterOnChange(sver);
JournalStreamer::Cancel();
13 changes: 2 additions & 11 deletions src/server/journal/streamer.h
@@ -47,7 +47,7 @@ class JournalStreamer {
void ThrottleIfNeeded();

virtual bool ShouldWrite(const journal::JournalItem& item) const {
- return !IsStopped();
+ return cntx_->IsRunning();
}

void WaitForInflightToComplete();
@@ -59,10 +59,6 @@
void AsyncWrite();
void OnCompletion(std::error_code ec, size_t len);

- bool IsStopped() const {
- return cntx_->IsCancelled();
- }

bool IsStalled() const;

journal::Journal* journal_;
@@ -92,10 +88,6 @@ class RestoreStreamer : public JournalStreamer {

void SendFinalize(long attempt);

- bool IsSnapshotFinished() const {
- return snapshot_finished_;
- }

private:
void OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req);
bool ShouldWrite(const journal::JournalItem& item) const override;
@@ -122,8 +114,7 @@
DbTableArray db_array_;
uint64_t snapshot_version_ = 0;
cluster::SlotSet my_slots_;
- bool fiber_cancelled_ = false;
- bool snapshot_finished_ = false;

ThreadLocalMutex big_value_mu_;
Stats stats_;
};
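The streamer.cc and streamer.h hunks drop RestoreStreamer's private `fiber_cancelled_` and `snapshot_finished_` flags and make the shared execution state the single source of truth for stopping the bucket traversal. A rough, compilable sketch of that call-site pattern, with stand-in types instead of the real prime-table cursor:

```cpp
#include <atomic>

// Same query surface as the ExecutionState sketch above.
struct CtxStub {
  enum class State { RUN, CANCELLED, ERROR };
  std::atomic<State> state{State::RUN};
  bool IsRunning() const { return state.load() == State::RUN; }
  void Cancel() { state.store(State::CANCELLED); }
};

// Before: `do { ... } while (cursor && !fiber_cancelled_);`
// After: the loop re-checks the shared state at every step, so both Cancel()
// and a reported error end the traversal without any streamer-local flag.
void RunSketch(CtxStub* cntx) {
  int cursor = 8;  // stand-in for the prime-table bucket cursor
  do {
    if (!cntx->IsRunning())  // may flip at any preemption point
      return;
    --cursor;                // stand-in for pt->TraverseBuckets(...)
  } while (cursor);
}

void CancelSketch(CtxStub* cntx) {
  cntx->Cancel();  // was: fiber_cancelled_ = true;
  // the real Cancel() also unregisters the change callback and cancels the
  // JournalStreamer, guarded by snapshot_version_ against a double cancel.
}
```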
2 changes: 1 addition & 1 deletion src/server/protocol_client.cc
@@ -158,7 +158,7 @@ error_code ProtocolClient::ConnectAndAuth(std::chrono::milliseconds connect_time
// The context closes sock_. So if the context error handler has already
// run we must not create a new socket. sock_mu_ syncs between the two
// functions.
- if (!cntx->IsCancelled()) {
+ if (cntx->IsRunning()) {
if (sock_) {
LOG_IF(WARNING, sock_->Close()) << "Error closing socket";
sock_.reset(nullptr);
12 changes: 6 additions & 6 deletions src/server/rdb_save.cc
@@ -1061,7 +1061,7 @@ class RdbSaver::Impl final : public SliceSnapshot::SnapshotDataConsumerInterface
void Finalize() override;

// used only for legacy rdb save flows.
- error_code ConsumeChannel(const Cancellation* cll);
+ error_code ConsumeChannel(const ExecutionState* cll);

void FillFreqMap(RdbTypeFreqMap* dest) const;

@@ -1161,7 +1161,7 @@ error_code RdbSaver::Impl::SaveAuxFieldStrStr(string_view key, string_view val)
return error_code{};
}

- error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
+ error_code RdbSaver::Impl::ConsumeChannel(const ExecutionState* es) {
error_code io_error;
string record;

Expand All @@ -1170,11 +1170,11 @@ error_code RdbSaver::Impl::ConsumeChannel(const Cancellation* cll) {
// we can not exit on io-error since we spawn fibers that push data.
// TODO: we may signal them to stop processing and exit asap in case of the error.
while (channel_->Pop(record)) {
- if (io_error || cll->IsCancelled())
+ if (io_error || (!es->IsRunning()))
continue;

do {
- if (cll->IsCancelled())
+ if (!es->IsRunning())
continue;

auto start = absl::GetCurrentTimeNanos();
@@ -1258,7 +1258,7 @@ void RdbSaver::Impl::WaitForSnapshottingFinish(EngineShard* shard) {
}

void RdbSaver::Impl::ConsumeData(std::string data, ExecutionState* cntx) {
- if (cntx->IsCancelled()) {
+ if (!cntx->IsRunning()) {
return;
}
if (channel_) { // Rdb write to channel
@@ -1468,7 +1468,7 @@ error_code RdbSaver::SaveBody(const ExecutionState& cntx) {

if (save_mode_ == SaveMode::RDB) {
VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
- error_code io_error = impl_->ConsumeChannel(cntx.GetCancellation());
+ error_code io_error = impl_->ConsumeChannel(&cntx);
if (io_error) {
return io_error;
}