diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b5059b2..459579c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -11,6 +11,7 @@ on: jobs: fmt: runs-on: windows-latest + timeout-minutes: 10 steps: - uses: actions/checkout@v4 - name: Setup dotnet @@ -26,6 +27,7 @@ jobs: test: runs-on: windows-latest + timeout-minutes: 10 steps: - uses: actions/checkout@v4 - name: Setup dotnet @@ -34,13 +36,36 @@ jobs: dotnet-version: 8.0.x cache: true cache-dependency-path: '**/packages.lock.json' + - name: Install Windows App SDK Runtime + shell: pwsh + run: | + $ErrorActionPreference = "Stop" + + $filename = ".\WindowsAppRuntimeInstall-x64.exe" + $url = "https://download.microsoft.com/download/7a3a6a44-b07e-4ca5-8b63-2de185769dbc/WindowsAppRuntimeInstall-x64.exe" # 1.6.5 (1.6.250205002) + & curl.exe --progress-bar --show-error --fail --location --output $filename $url + if ($LASTEXITCODE -ne 0) { throw "Failed to download Windows App SDK" } + + $process = Start-Process -FilePath $filename -ArgumentList "--quiet --force" -NoNewWindow -Wait -PassThru + if ($process.ExitCode -ne 0) { throw "Failed to install Windows App SDK: exit code is $($process.ExitCode)" } - name: dotnet restore run: dotnet restore --locked-mode - name: dotnet test - run: dotnet test --no-restore + run: dotnet test --no-restore --blame-hang --blame-hang-dump-type full --blame-hang-timeout 2m -p:Platform=x64 + - name: Upload test binaries and TestResults + if: failure() + uses: actions/upload-artifact@v4 + with: + name: test-results + retention-days: 1 + path: | + ./**/bin + ./**/obj + ./**/TestResults build: runs-on: windows-latest + timeout-minutes: 10 steps: - uses: actions/checkout@v4 - name: Setup dotnet diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4be431b..e6849aa 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -18,6 +18,7 @@ permissions: jobs: release: runs-on: ${{ github.repository_owner == 'coder' && 'windows-latest-16-cores' || 'windows-latest' }} + timeout-minutes: 15 steps: - uses: actions/checkout@v4 diff --git a/App/App.xaml.cs b/App/App.xaml.cs index af4217e..9895fc8 100644 --- a/App/App.xaml.cs +++ b/App/App.xaml.cs @@ -1,5 +1,7 @@ using System; +using System.Threading; using System.Threading.Tasks; +using Coder.Desktop.App.Models; using Coder.Desktop.App.Services; using Coder.Desktop.App.ViewModels; using Coder.Desktop.App.Views; @@ -26,6 +28,7 @@ public App() services.AddTransient(); // TrayWindow views and view models + services.AddTransient(); services.AddTransient(); services.AddTransient(); services.AddTransient(); @@ -45,17 +48,29 @@ public async Task ExitApplication() { _handleWindowClosed = false; Exit(); - var rpcManager = _services.GetRequiredService(); + var rpcController = _services.GetRequiredService(); // TODO: send a StopRequest if we're connected??? - await rpcManager.DisposeAsync(); + await rpcController.DisposeAsync(); Environment.Exit(0); } protected override void OnLaunched(LaunchActivatedEventArgs args) { - var trayWindow = _services.GetRequiredService(); + // Start connecting to the manager in the background. + var rpcController = _services.GetRequiredService(); + if (rpcController.GetState().RpcLifecycle == RpcLifecycle.Disconnected) + // Passing in a CT with no cancellation is desired here, because + // the named pipe open will block until the pipe comes up. + _ = rpcController.Reconnect(CancellationToken.None); + + // Load the credentials in the background. Even though we pass a CT + // with no cancellation, the method itself will impose a timeout on the + // HTTP portion. + var credentialManager = _services.GetRequiredService(); + _ = credentialManager.LoadCredentials(CancellationToken.None); // Prevent the TrayWindow from closing, just hide it. + var trayWindow = _services.GetRequiredService(); trayWindow.Closed += (sender, args) => { if (!_handleWindowClosed) return; diff --git a/App/Models/CredentialModel.cs b/App/Models/CredentialModel.cs index 5388722..542c1c0 100644 --- a/App/Models/CredentialModel.cs +++ b/App/Models/CredentialModel.cs @@ -2,16 +2,24 @@ namespace Coder.Desktop.App.Models; public enum CredentialState { + // Unknown means "we haven't checked yet" + Unknown, + + // Invalid means "we checked and there's either no saved credentials or they are not valid" Invalid, + + // Valid means "we checked and there are saved credentials and they are valid" Valid, } public class CredentialModel { - public CredentialState State { get; set; } = CredentialState.Invalid; + public CredentialState State { get; init; } = CredentialState.Unknown; + + public string? CoderUrl { get; init; } + public string? ApiToken { get; init; } - public string? CoderUrl { get; set; } - public string? ApiToken { get; set; } + public string? Username { get; init; } public CredentialModel Clone() { @@ -20,6 +28,7 @@ public CredentialModel Clone() State = State, CoderUrl = CoderUrl, ApiToken = ApiToken, + Username = Username, }; } } diff --git a/App/Services/CredentialManager.cs b/App/Services/CredentialManager.cs index a3456b7..41a8dc7 100644 --- a/App/Services/CredentialManager.cs +++ b/App/Services/CredentialManager.cs @@ -6,8 +6,8 @@ using System.Threading; using System.Threading.Tasks; using Coder.Desktop.App.Models; +using Coder.Desktop.CoderSdk; using Coder.Desktop.Vpn.Utilities; -using CoderSdk; namespace Coder.Desktop.App.Services; @@ -18,119 +18,296 @@ public class RawCredentials } [JsonSerializable(typeof(RawCredentials))] -public partial class RawCredentialsJsonContext : JsonSerializerContext -{ -} +public partial class RawCredentialsJsonContext : JsonSerializerContext; public interface ICredentialManager { public event EventHandler CredentialsChanged; - public CredentialModel GetCredentials(); + /// + /// Returns cached credentials or an invalid credential model if none are cached. It's preferable to use + /// LoadCredentials if you are operating in an async context. + /// + public CredentialModel GetCachedCredentials(); + + /// + /// Get any sign-in URL. The returned value is not parsed to check if it's a valid URI. + /// + public Task GetSignInUri(); + + /// + /// Returns cached credentials or loads/verifies them from storage if not cached. + /// + public Task LoadCredentials(CancellationToken ct = default); public Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct = default); - public void ClearCredentials(); + public Task ClearCredentials(CancellationToken ct = default); +} + +public interface ICredentialBackend +{ + public Task ReadCredentials(CancellationToken ct = default); + public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default); + public Task DeleteCredentials(CancellationToken ct = default); } +/// +/// Implements ICredentialManager using an ICredentialBackend to store +/// credentials. +/// public class CredentialManager : ICredentialManager { private const string CredentialsTargetName = "Coder.Desktop.App.Credentials"; - private readonly RaiiSemaphoreSlim _lock = new(1, 1); - private CredentialModel? _latestCredentials; + // _opLock is held for the full duration of SetCredentials, and partially + // during LoadCredentials. _opLock protects _inFlightLoad, _loadCts, and + // writes to _latestCredentials. + private readonly RaiiSemaphoreSlim _opLock = new(1, 1); + + // _inFlightLoad and _loadCts are set at the beginning of a LoadCredentials + // call. + private Task? _inFlightLoad; + private CancellationTokenSource? _loadCts; + + // Reading and writing a reference in C# is always atomic, so this doesn't + // need to be protected on reads with a lock in GetCachedCredentials. + // + // The volatile keyword disables optimizations on reads/writes which helps + // other threads see the new value quickly (no guarantee that it's + // immediate). + private volatile CredentialModel? _latestCredentials; + + private ICredentialBackend Backend { get; } = new WindowsCredentialBackend(CredentialsTargetName); + + private ICoderApiClientFactory CoderApiClientFactory { get; } = new CoderApiClientFactory(); + + public CredentialManager() + { + } + + public CredentialManager(ICredentialBackend backend, ICoderApiClientFactory coderApiClientFactory) + { + Backend = backend; + CoderApiClientFactory = coderApiClientFactory; + } public event EventHandler? CredentialsChanged; - public CredentialModel GetCredentials() + public CredentialModel GetCachedCredentials() { - using var _ = _lock.Lock(); - if (_latestCredentials != null) return _latestCredentials.Clone(); + // No lock required to read the reference. + var latestCreds = _latestCredentials; + // No clone needed as the model is immutable. + if (latestCreds != null) return latestCreds; - var rawCredentials = ReadCredentials(); - if (rawCredentials is null) - _latestCredentials = new CredentialModel - { - State = CredentialState.Invalid, - }; - else - _latestCredentials = new CredentialModel - { - State = CredentialState.Valid, - CoderUrl = rawCredentials.CoderUrl, - ApiToken = rawCredentials.ApiToken, - }; - return _latestCredentials.Clone(); + return new CredentialModel + { + State = CredentialState.Unknown, + }; } - public async Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct = default) + public async Task GetSignInUri() { + try + { + var raw = await Backend.ReadCredentials(); + if (raw is not null && !string.IsNullOrWhiteSpace(raw.CoderUrl)) return raw.CoderUrl; + } + catch + { + // ignored + } + + return null; + } + + // LoadCredentials may be preempted by SetCredentials. + public Task LoadCredentials(CancellationToken ct = default) + { + // This function is not `async` because we may return an existing task. + // However, we still want to acquire the lock with the + // CancellationToken so it can be canceled if needed. + using var _ = _opLock.LockAsync(ct).Result; + + // If we already have a cached value, return it. + var latestCreds = _latestCredentials; + if (latestCreds != null) return Task.FromResult(latestCreds); + + // If we are already loading, return the existing task. + if (_inFlightLoad != null) return _inFlightLoad; + + // Otherwise, kick off a new load. + // Note: subsequent loads returned from above will ignore the passed in + // CancellationToken. We set a maximum timeout of 15 seconds anyway. + _loadCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + _loadCts.CancelAfter(TimeSpan.FromSeconds(15)); + _inFlightLoad = LoadCredentialsInner(_loadCts.Token); + return _inFlightLoad; + } + + public async Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct) + { + using var _ = await _opLock.LockAsync(ct); + + // If there's an ongoing load, cancel it. + if (_loadCts != null) + { + await _loadCts.CancelAsync(); + _loadCts.Dispose(); + _loadCts = null; + _inFlightLoad = null; + } + if (string.IsNullOrWhiteSpace(coderUrl)) throw new ArgumentException("Coder URL is required", nameof(coderUrl)); coderUrl = coderUrl.Trim(); - if (coderUrl.Length > 128) throw new ArgumentOutOfRangeException(nameof(coderUrl), "Coder URL is too long"); + if (coderUrl.Length > 128) throw new ArgumentException("Coder URL is too long", nameof(coderUrl)); if (!Uri.TryCreate(coderUrl, UriKind.Absolute, out var uri)) throw new ArgumentException($"Coder URL '{coderUrl}' is not a valid URL", nameof(coderUrl)); + if (uri.Scheme != "http" && uri.Scheme != "https") + throw new ArgumentException("Coder URL must be HTTP or HTTPS", nameof(coderUrl)); if (uri.PathAndQuery != "/") throw new ArgumentException("Coder URL must be the root URL", nameof(coderUrl)); if (string.IsNullOrWhiteSpace(apiToken)) throw new ArgumentException("API token is required", nameof(apiToken)); apiToken = apiToken.Trim(); - if (apiToken.Length != 33) - throw new ArgumentOutOfRangeException(nameof(apiToken), "API token must be 33 characters long"); - - try - { - var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); - cts.CancelAfter(TimeSpan.FromSeconds(15)); - var sdkClient = new CoderApiClient(uri); - sdkClient.SetSessionToken(apiToken); - // TODO: we should probably perform a version check here too, - // rather than letting the service do it on Start - _ = await sdkClient.GetBuildInfo(cts.Token); - _ = await sdkClient.GetUser(User.Me, cts.Token); - } - catch (Exception e) - { - throw new InvalidOperationException("Could not connect to or verify Coder server", e); - } - WriteCredentials(new RawCredentials + var raw = new RawCredentials { CoderUrl = coderUrl, ApiToken = apiToken, - }); + }; + var populateCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + populateCts.CancelAfter(TimeSpan.FromSeconds(15)); + var model = await PopulateModel(raw, populateCts.Token); + await Backend.WriteCredentials(raw, ct); + UpdateState(model); + } + public async Task ClearCredentials(CancellationToken ct = default) + { + using var _ = await _opLock.LockAsync(ct); + await Backend.DeleteCredentials(ct); UpdateState(new CredentialModel { - State = CredentialState.Valid, - CoderUrl = coderUrl, - ApiToken = apiToken, + State = CredentialState.Invalid, }); } - public void ClearCredentials() + private async Task LoadCredentialsInner(CancellationToken ct) { - NativeApi.DeleteCredentials(CredentialsTargetName); - UpdateState(new CredentialModel + CredentialModel model; + try { - State = CredentialState.Invalid, - CoderUrl = null, - ApiToken = null, - }); + var raw = await Backend.ReadCredentials(ct); + model = await PopulateModel(raw, ct); + } + catch + { + // This catch will be hit if a SetCredentials operation started, or + // if the read/populate failed for some other reason (e.g. HTTP + // timeout). + // + // We don't need to clear the credentials here, the app will think + // they're unset and any subsequent SetCredentials call after the + // user signs in again will overwrite the old invalid ones. + model = new CredentialModel + { + State = CredentialState.Invalid, + }; + } + + // Grab the lock again so we can update the state. + using (await _opLock.LockAsync(ct)) + { + // Prevent new LoadCredentials calls from returning this task. + if (_loadCts != null) + { + _loadCts.Dispose(); + _loadCts = null; + _inFlightLoad = null; + } + + // If we were canceled but made it this far, try to return the + // latest credentials instead. + if (ct.IsCancellationRequested) + { + var latestCreds = _latestCredentials; + if (latestCreds is not null) return latestCreds; + } + + // If there aren't any latest credentials after a cancellation, we + // most likely timed out and should throw. + ct.ThrowIfCancellationRequested(); + + UpdateState(model); + return model; + } } - private void UpdateState(CredentialModel newModel) + private async Task PopulateModel(RawCredentials? credentials, CancellationToken ct) { - using (_lock.Lock()) + if (credentials is null || string.IsNullOrWhiteSpace(credentials.CoderUrl) || + string.IsNullOrWhiteSpace(credentials.ApiToken)) + return new CredentialModel + { + State = CredentialState.Invalid, + }; + + BuildInfo buildInfo; + User me; + try + { + var sdkClient = CoderApiClientFactory.Create(credentials.CoderUrl); + // BuildInfo does not require authentication. + buildInfo = await sdkClient.GetBuildInfo(ct); + sdkClient.SetSessionToken(credentials.ApiToken); + me = await sdkClient.GetUser(User.Me, ct); + } + catch (CoderApiHttpException) { - _latestCredentials = newModel.Clone(); + throw; } + catch (Exception e) + { + throw new InvalidOperationException("Could not connect to or verify Coder server", e); + } + + ServerVersionUtilities.ParseAndValidateServerVersion(buildInfo.Version); + if (string.IsNullOrWhiteSpace(me.Username)) + throw new InvalidOperationException("Could not retrieve user information, username is empty"); + + return new CredentialModel + { + State = CredentialState.Valid, + CoderUrl = credentials.CoderUrl, + ApiToken = credentials.ApiToken, + Username = me.Username, + }; + } - CredentialsChanged?.Invoke(this, newModel.Clone()); + // Lock must be held when calling this function. + private void UpdateState(CredentialModel newModel) + { + _latestCredentials = newModel; + // Since the event handlers could block (or call back the + // CredentialManager and deadlock), we run these in a new task. + if (CredentialsChanged == null) return; + Task.Run(() => { CredentialsChanged?.Invoke(this, newModel); }); } +} - private static RawCredentials? ReadCredentials() +public class WindowsCredentialBackend : ICredentialBackend +{ + private readonly string _credentialsTargetName; + + public WindowsCredentialBackend(string credentialsTargetName) { - var raw = NativeApi.ReadCredentials(CredentialsTargetName); - if (raw == null) return null; + _credentialsTargetName = credentialsTargetName; + } + + public Task ReadCredentials(CancellationToken ct = default) + { + var raw = NativeApi.ReadCredentials(_credentialsTargetName); + if (raw == null) return Task.FromResult(null); RawCredentials? credentials; try @@ -139,19 +316,23 @@ private void UpdateState(CredentialModel newModel) } catch (JsonException) { - return null; + credentials = null; } - if (credentials is null || string.IsNullOrWhiteSpace(credentials.CoderUrl) || - string.IsNullOrWhiteSpace(credentials.ApiToken)) return null; - - return credentials; + return Task.FromResult(credentials); } - private static void WriteCredentials(RawCredentials credentials) + public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default) { var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials); - NativeApi.WriteCredentials(CredentialsTargetName, raw); + NativeApi.WriteCredentials(_credentialsTargetName, raw); + return Task.CompletedTask; + } + + public Task DeleteCredentials(CancellationToken ct = default) + { + NativeApi.DeleteCredentials(_credentialsTargetName); + return Task.CompletedTask; } private static class NativeApi diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index 1b3dac6..248a011 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -156,9 +156,10 @@ public async Task StartVpn(CancellationToken ct = default) using var _ = await AcquireOperationLockNowAsync(); AssertRpcConnected(); - var credentials = _credentialManager.GetCredentials(); + var credentials = _credentialManager.GetCachedCredentials(); if (credentials.State != CredentialState.Valid) - throw new RpcOperationException("Cannot start VPN without valid credentials"); + throw new RpcOperationException( + $"Cannot start VPN without valid credentials, current state: {credentials.State}"); MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; }); diff --git a/App/ViewModels/SignInViewModel.cs b/App/ViewModels/SignInViewModel.cs index ae64f2b..fcd47d4 100644 --- a/App/ViewModels/SignInViewModel.cs +++ b/App/ViewModels/SignInViewModel.cs @@ -6,6 +6,7 @@ using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; using Microsoft.UI.Xaml; +using Microsoft.UI.Xaml.Controls; namespace Coder.Desktop.App.ViewModels; @@ -33,8 +34,6 @@ public partial class SignInViewModel : ObservableObject [NotifyPropertyChangedFor(nameof(ApiTokenError))] public partial bool ApiTokenTouched { get; set; } = false; - [ObservableProperty] public partial string? SignInError { get; set; } = null; - [ObservableProperty] public partial bool SignInLoading { get; set; } = false; public string? CoderUrlError => CoderUrlTouched ? _coderUrlError : null; @@ -82,6 +81,29 @@ public SignInViewModel(ICredentialManager credentialManager) _credentialManager = credentialManager; } + // When the URL box loads, get the old URI from the credential manager. + // This is an async operation on paper, but we would expect it to be + // synchronous or extremely quick in practice. + public void CoderUrl_Loaded(object sender, RoutedEventArgs e) + { + if (sender is not TextBox textBox) return; + + var dispatcherQueue = textBox.DispatcherQueue; + _credentialManager.GetSignInUri().ContinueWith(t => + { + if (t.IsCompleted && !string.IsNullOrWhiteSpace(t.Result)) + dispatcherQueue.TryEnqueue(() => + { + if (!CoderUrlTouched) + { + CoderUrl = t.Result; + CoderUrlTouched = true; + textBox.SelectionStart = CoderUrl.Length; + } + }); + }); + } + public void CoderUrl_FocusLost(object sender, RoutedEventArgs e) { CoderUrlTouched = true; @@ -117,7 +139,6 @@ public async Task TokenPage_SignIn(SignInWindow signInWindow) try { SignInLoading = true; - SignInError = null; var cts = new CancellationTokenSource(TimeSpan.FromSeconds(15)); await _credentialManager.SetCredentials(CoderUrl.Trim(), ApiToken.Trim(), cts.Token); @@ -126,7 +147,14 @@ public async Task TokenPage_SignIn(SignInWindow signInWindow) } catch (Exception e) { - SignInError = $"Failed to sign in: {e}"; + var dialog = new ContentDialog + { + Title = "Failed to sign in", + Content = $"{e}", + CloseButtonText = "Ok", + XamlRoot = signInWindow.Content.XamlRoot, + }; + _ = await dialog.ShowAsync(); } finally { diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index 204d9f0..7d122df 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -87,7 +87,7 @@ public void Initialize(DispatcherQueue dispatcherQueue) UpdateFromRpcModel(_rpcController.GetState()); _credentialManager.CredentialsChanged += (_, credentialModel) => UpdateFromCredentialsModel(credentialModel); - UpdateFromCredentialsModel(_credentialManager.GetCredentials()); + UpdateFromCredentialsModel(_credentialManager.GetCachedCredentials()); } private void UpdateFromRpcModel(RpcModel rpcModel) @@ -114,7 +114,7 @@ private void UpdateFromRpcModel(RpcModel rpcModel) VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started; // Get the current dashboard URL. - var credentialModel = _credentialManager.GetCredentials(); + var credentialModel = _credentialManager.GetCachedCredentials(); Uri? coderUri = null; if (credentialModel.State == CredentialState.Valid && !string.IsNullOrWhiteSpace(credentialModel.CoderUrl)) try diff --git a/App/Views/Pages/SignInTokenPage.xaml b/App/Views/Pages/SignInTokenPage.xaml index 93a1796..8613f19 100644 --- a/App/Views/Pages/SignInTokenPage.xaml +++ b/App/Views/Pages/SignInTokenPage.xaml @@ -95,10 +95,5 @@ Command="{x:Bind ViewModel.TokenPage_SignInCommand, Mode=OneWay}" CommandParameter="{x:Bind SignInWindow, Mode=OneWay}" /> - - diff --git a/App/Views/Pages/SignInUrlPage.xaml b/App/Views/Pages/SignInUrlPage.xaml index 1c12b03..76f6a3a 100644 --- a/App/Views/Pages/SignInUrlPage.xaml +++ b/App/Views/Pages/SignInUrlPage.xaml @@ -46,6 +46,7 @@ Grid.Row="0" HorizontalAlignment="Stretch" PlaceholderText="https://coder.example.com" + Loaded="{x:Bind ViewModel.CoderUrl_Loaded, Mode=OneWay}" LostFocus="{x:Bind ViewModel.CoderUrl_FocusLost, Mode=OneWay}" Text="{x:Bind ViewModel.CoderUrl, Mode=TwoWay}" /> diff --git a/App/Views/Pages/TrayWindowLoadingPage.xaml b/App/Views/Pages/TrayWindowLoadingPage.xaml new file mode 100644 index 0000000..6e103ad --- /dev/null +++ b/App/Views/Pages/TrayWindowLoadingPage.xaml @@ -0,0 +1,26 @@ + + + + + + + + + + diff --git a/App/Views/Pages/TrayWindowLoadingPage.xaml.cs b/App/Views/Pages/TrayWindowLoadingPage.xaml.cs new file mode 100644 index 0000000..9b207a7 --- /dev/null +++ b/App/Views/Pages/TrayWindowLoadingPage.xaml.cs @@ -0,0 +1,11 @@ +using Microsoft.UI.Xaml.Controls; + +namespace Coder.Desktop.App.Views.Pages; + +public sealed partial class TrayWindowLoadingPage : Page +{ + public TrayWindowLoadingPage() + { + InitializeComponent(); + } +} diff --git a/App/Views/TrayWindow.xaml.cs b/App/Views/TrayWindow.xaml.cs index b528723..7fd1482 100644 --- a/App/Views/TrayWindow.xaml.cs +++ b/App/Views/TrayWindow.xaml.cs @@ -1,6 +1,5 @@ using System; using System.Runtime.InteropServices; -using System.Threading; using Windows.Foundation; using Windows.Graphics; using Windows.System; @@ -28,16 +27,19 @@ public sealed partial class TrayWindow : Window private readonly IRpcController _rpcController; private readonly ICredentialManager _credentialManager; + private readonly TrayWindowLoadingPage _loadingPage; private readonly TrayWindowDisconnectedPage _disconnectedPage; private readonly TrayWindowLoginRequiredPage _loginRequiredPage; private readonly TrayWindowMainPage _mainPage; public TrayWindow(IRpcController rpcController, ICredentialManager credentialManager, + TrayWindowLoadingPage loadingPage, TrayWindowDisconnectedPage disconnectedPage, TrayWindowLoginRequiredPage loginRequiredPage, TrayWindowMainPage mainPage) { _rpcController = rpcController; _credentialManager = credentialManager; + _loadingPage = loadingPage; _disconnectedPage = disconnectedPage; _loginRequiredPage = loginRequiredPage; _mainPage = mainPage; @@ -49,9 +51,7 @@ public TrayWindow(IRpcController rpcController, ICredentialManager credentialMan rpcController.StateChanged += RpcController_StateChanged; credentialManager.CredentialsChanged += CredentialManager_CredentialsChanged; - SetPageByState(rpcController.GetState(), credentialManager.GetCredentials()); - - _rpcController.Reconnect(CancellationToken.None); + SetPageByState(rpcController.GetState(), credentialManager.GetCachedCredentials()); // Setting OpenCommand and ExitCommand directly in the .xaml doesn't seem to work for whatever reason. TrayIcon.OpenCommand = Tray_OpenCommand; @@ -78,6 +78,12 @@ public TrayWindow(IRpcController rpcController, ICredentialManager credentialMan private void SetPageByState(RpcModel rpcModel, CredentialModel credentialModel) { + if (credentialModel.State == CredentialState.Unknown) + { + SetRootFrame(_loadingPage); + return; + } + switch (rpcModel.RpcLifecycle) { case RpcLifecycle.Connected: @@ -96,7 +102,7 @@ private void SetPageByState(RpcModel rpcModel, CredentialModel credentialModel) private void RpcController_StateChanged(object? _, RpcModel model) { - SetPageByState(model, _credentialManager.GetCredentials()); + SetPageByState(model, _credentialManager.GetCachedCredentials()); } private void CredentialManager_CredentialsChanged(object? _, CredentialModel model) diff --git a/App/packages.lock.json b/App/packages.lock.json index e547ab4..264df38 100644 --- a/App/packages.lock.json +++ b/App/packages.lock.json @@ -71,6 +71,11 @@ "resolved": "9.0.1", "contentHash": "Tr74eP0oQ3AyC24ch17N8PuEkrPbD0JqIfENCYqmgKYNOmL8wQKzLJu3ObxTUDrjnn4rHoR1qKa37/eQyHmCDA==" }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "5.0.1", + "contentHash": "5WPSmL4YeP7eW+Vc8XZ4DwjYWBAiSwDV9Hm63JJWcz1Ie3Xjv4KuJXzgCstj48LkLfVCYa7mLcx7y+q6yqVvtw==" + }, "Microsoft.Web.WebView2": { "type": "Transitive", "resolved": "1.0.2651.64", @@ -86,6 +91,14 @@ "resolved": "10.0.22621.756", "contentHash": "7ZL2sFSioYm1Ry067Kw1hg0SCcW5kuVezC2SwjGbcPE61Nn+gTbH86T73G3LcEOVj0S3IZzNuE/29gZvOLS7VA==" }, + "Semver": { + "type": "Transitive", + "resolved": "3.0.0", + "contentHash": "9jZCicsVgTebqkAujRWtC9J1A5EQVlu0TVKHcgoCuv345ve5DYf4D1MjhKEnQjdRZo6x/vdv6QQrYFs7ilGzLA==", + "dependencies": { + "Microsoft.Extensions.Primitives": "5.0.1" + } + }, "System.Collections.Immutable": { "type": "Transitive", "resolved": "9.0.0", @@ -119,6 +132,7 @@ "type": "Project", "dependencies": { "Coder.Desktop.Vpn.Proto": "[1.0.0, )", + "Semver": "[3.0.0, )", "System.IO.Pipelines": "[9.0.1, )" } }, diff --git a/Coder.Desktop.sln b/Coder.Desktop.sln index 7d85caa..0a20185 100644 --- a/Coder.Desktop.sln +++ b/Coder.Desktop.sln @@ -23,6 +23,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Vpn.DebugClient", "Vpn.Debu EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Installer", "Installer\Installer.csproj", "{39F5B55A-09D8-477D-A3FA-ADAC29C52605}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.App", "Tests.App\Tests.App.csproj", "{3E91CED7-5528-4B46-8722-FB95D4FAB967}" +EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MutagenSdk", "MutagenSdk\MutagenSdk.csproj", "{E2477ADC-03DA-490D-9369-79A4CC4A58D2}" EndProject Global @@ -205,6 +207,22 @@ Global {39F5B55A-09D8-477D-A3FA-ADAC29C52605}.Release|x64.Build.0 = Release|Any CPU {39F5B55A-09D8-477D-A3FA-ADAC29C52605}.Release|x86.ActiveCfg = Release|Any CPU {39F5B55A-09D8-477D-A3FA-ADAC29C52605}.Release|x86.Build.0 = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|ARM64.ActiveCfg = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|ARM64.Build.0 = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|x64.ActiveCfg = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|x64.Build.0 = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|x86.ActiveCfg = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Debug|x86.Build.0 = Debug|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|Any CPU.Build.0 = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|ARM64.ActiveCfg = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|ARM64.Build.0 = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|x64.ActiveCfg = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|x64.Build.0 = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|x86.ActiveCfg = Release|Any CPU + {3E91CED7-5528-4B46-8722-FB95D4FAB967}.Release|x86.Build.0 = Release|Any CPU {E2477ADC-03DA-490D-9369-79A4CC4A58D2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {E2477ADC-03DA-490D-9369-79A4CC4A58D2}.Debug|Any CPU.Build.0 = Debug|Any CPU {E2477ADC-03DA-490D-9369-79A4CC4A58D2}.Debug|ARM64.ActiveCfg = Debug|Any CPU @@ -225,4 +243,7 @@ Global GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {FC108D8D-B425-4DA0-B9CC-69670BCF4835} + EndGlobalSection EndGlobal diff --git a/CoderSdk/CoderApiClient.cs b/CoderSdk/CoderApiClient.cs index 016998d..df2d923 100644 --- a/CoderSdk/CoderApiClient.cs +++ b/CoderSdk/CoderApiClient.cs @@ -2,7 +2,25 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace CoderSdk; +namespace Coder.Desktop.CoderSdk; + +public interface ICoderApiClientFactory +{ + public ICoderApiClient Create(string baseUrl); +} + +public class CoderApiClientFactory : ICoderApiClientFactory +{ + public ICoderApiClient Create(string baseUrl) + { + return new CoderApiClient(baseUrl); + } +} + +public partial interface ICoderApiClient +{ + public void SetSessionToken(string token); +} /// /// Changes names from PascalCase to snake_case. @@ -18,19 +36,26 @@ public override string ConvertName(string name) } [JsonSerializable(typeof(BuildInfo))] +[JsonSerializable(typeof(Response))] [JsonSerializable(typeof(User))] -public partial class CoderSdkJsonContext : JsonSerializerContext -{ -} +[JsonSerializable(typeof(ValidationError))] +public partial class CoderSdkJsonContext : JsonSerializerContext; /// /// Provides a limited selection of API methods for a Coder instance. /// -public partial class CoderApiClient +public partial class CoderApiClient : ICoderApiClient { + public static readonly JsonSerializerOptions JsonOptions = new() + { + TypeInfoResolver = CoderSdkJsonContext.Default, + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = new SnakeCaseNamingPolicy(), + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + // TODO: allow adding headers private readonly HttpClient _httpClient = new(); - private readonly JsonSerializerOptions _jsonOptions; public CoderApiClient(string baseUrl) : this(new Uri(baseUrl, UriKind.Absolute)) { @@ -41,13 +66,6 @@ public CoderApiClient(Uri baseUrl) if (baseUrl.PathAndQuery != "/") throw new ArgumentException($"Base URL '{baseUrl}' must not contain a path", nameof(baseUrl)); _httpClient.BaseAddress = baseUrl; - _jsonOptions = new JsonSerializerOptions - { - TypeInfoResolver = CoderSdkJsonContext.Default, - PropertyNameCaseInsensitive = true, - PropertyNamingPolicy = new SnakeCaseNamingPolicy(), - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - }; } public CoderApiClient(string baseUrl, string token) : this(baseUrl) @@ -76,22 +94,26 @@ private async Task SendRequestAsync(HttpMethod m if (payload is not null) { - var json = JsonSerializer.Serialize(payload, typeof(TRequest), _jsonOptions); + var json = JsonSerializer.Serialize(payload, typeof(TRequest), JsonOptions); request.Content = new StringContent(json, Encoding.UTF8, "application/json"); } var res = await _httpClient.SendAsync(request, ct); - // TODO: this should be improved to try and parse a codersdk.Error response - res.EnsureSuccessStatusCode(); + if (!res.IsSuccessStatusCode) + throw await CoderApiHttpException.FromResponse(res, ct); var content = await res.Content.ReadAsStringAsync(ct); - var data = JsonSerializer.Deserialize(content, _jsonOptions); + var data = JsonSerializer.Deserialize(content, JsonOptions); if (data is null) throw new JsonException("Deserialized response is null"); return data; } + catch (CoderApiHttpException) + { + throw; + } catch (Exception e) { - throw new Exception($"API Request: {method} {path} (req body: {payload is not null})", e); + throw new Exception($"Coder API Request failed: {method} {path}", e); } } } diff --git a/CoderSdk/Deployment.cs b/CoderSdk/Deployment.cs index d85a458..e95e039 100644 --- a/CoderSdk/Deployment.cs +++ b/CoderSdk/Deployment.cs @@ -1,16 +1,13 @@ -namespace CoderSdk; +namespace Coder.Desktop.CoderSdk; + +public partial interface ICoderApiClient +{ + public Task GetBuildInfo(CancellationToken ct = default); +} public class BuildInfo { - public string ExternalUrl { get; set; } = ""; public string Version { get; set; } = ""; - public string DashboardUrl { get; set; } = ""; - public bool Telemetry { get; set; } = false; - public bool WorkspaceProxy { get; set; } = false; - public string AgentApiVersion { get; set; } = ""; - public string ProvisionerApiVersion { get; set; } = ""; - public string UpgradeMessage { get; set; } = ""; - public string DeploymentId { get; set; } = ""; } public partial class CoderApiClient diff --git a/CoderSdk/Errors.cs b/CoderSdk/Errors.cs new file mode 100644 index 0000000..4d79a59 --- /dev/null +++ b/CoderSdk/Errors.cs @@ -0,0 +1,82 @@ +using System.Net; +using System.Text.Json; + +namespace Coder.Desktop.CoderSdk; + +public class ValidationError +{ + public string Field { get; set; } = ""; + public string Detail { get; set; } = ""; +} + +public class Response +{ + public string Message { get; set; } = ""; + public string Detail { get; set; } = ""; + public List Validations { get; set; } = []; +} + +public class CoderApiHttpException : Exception +{ + private static readonly Dictionary Helpers = new() + { + { HttpStatusCode.Unauthorized, "Try signing in again" }, + }; + + public readonly HttpMethod? Method; + public readonly Uri? RequestUri; + public readonly HttpStatusCode StatusCode; + public readonly string? ReasonPhrase; + public readonly Response Response; + + public CoderApiHttpException(HttpMethod? method, Uri? requestUri, HttpStatusCode statusCode, string? reasonPhrase, + Response response) : base(MessageFrom(method, requestUri, statusCode, reasonPhrase, response)) + { + Method = method; + RequestUri = requestUri; + StatusCode = statusCode; + ReasonPhrase = reasonPhrase; + Response = response; + } + + public static async Task FromResponse(HttpResponseMessage response, CancellationToken ct) + { + var content = await response.Content.ReadAsStringAsync(ct); + Response? responseObject; + try + { + responseObject = JsonSerializer.Deserialize(content, CoderApiClient.JsonOptions); + } + catch (JsonException) + { + responseObject = null; + } + + if (responseObject is null or { Message: null or "" }) + responseObject = new Response + { + Message = "Could not parse response, or response has no message", + Detail = content, + Validations = [], + }; + + return new CoderApiHttpException( + response.RequestMessage?.Method, + response.RequestMessage?.RequestUri, + response.StatusCode, + response.ReasonPhrase, + responseObject); + } + + private static string MessageFrom(HttpMethod? method, Uri? requestUri, HttpStatusCode statusCode, + string? reasonPhrase, Response response) + { + var message = $"Coder API Request: {method} '{requestUri}' failed with status code {(int)statusCode}"; + if (!string.IsNullOrEmpty(reasonPhrase)) message += $" {reasonPhrase}"; + message += $": {response.Message}"; + if (Helpers.TryGetValue(statusCode, out var helperMessage)) message += $": {helperMessage}"; + if (!string.IsNullOrEmpty(response.Detail)) message += $"\n\tError: {response.Detail}"; + foreach (var validation in response.Validations) message += $"\n\t{validation.Field}: {validation.Detail}"; + return message; + } +} diff --git a/CoderSdk/Users.cs b/CoderSdk/Users.cs index 2d99e02..fd81b32 100644 --- a/CoderSdk/Users.cs +++ b/CoderSdk/Users.cs @@ -1,10 +1,14 @@ -namespace CoderSdk; +namespace Coder.Desktop.CoderSdk; + +public partial interface ICoderApiClient +{ + public Task GetUser(string user, CancellationToken ct = default); +} public class User { public const string Me = "me"; - // TODO: fill out more fields public string Username { get; set; } = ""; } diff --git a/Tests.App/Services/CredentialManagerTest.cs b/Tests.App/Services/CredentialManagerTest.cs new file mode 100644 index 0000000..2fa4699 --- /dev/null +++ b/Tests.App/Services/CredentialManagerTest.cs @@ -0,0 +1,323 @@ +using System.Diagnostics; +using Coder.Desktop.App.Models; +using Coder.Desktop.App.Services; +using Coder.Desktop.CoderSdk; +using Moq; + +namespace Coder.Desktop.Tests.App.Services; + +[TestFixture] +public class CredentialManagerTest +{ + private const string TestServerUrl = "https://dev.coder.com"; + private const string TestApiToken = "abcdef1234-abcdef1234567890ABCDEF"; + private const string TestUsername = "dean"; + + [Test(Description = "End to end test with WindowsCredentialBackend")] + [CancelAfter(30_000)] + public async Task EndToEnd(CancellationToken ct) + { + var credentialBackend = new WindowsCredentialBackend($"Coder.Desktop.Test.App.{Guid.NewGuid()}"); + + // I lied. It's not fully end to end. We don't use a real or fake API + // server for this and use a mock client instead. + var apiClient = new Mock(MockBehavior.Strict); + apiClient.Setup(x => x.SetSessionToken(TestApiToken)); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny()).Result) + .Returns(new BuildInfo { Version = "v2.20.0" }); + apiClient.Setup(x => x.GetUser(User.Me, It.IsAny()).Result) + .Returns(new User { Username = TestUsername }); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object); + + try + { + var manager1 = new CredentialManager(credentialBackend, apiClientFactory.Object); + + // Cached credential should be unknown. + var cred = manager1.GetCachedCredentials(); + Assert.That(cred.State, Is.EqualTo(CredentialState.Unknown)); + + // Load credentials from backend. No credentials are stored so it + // should be invalid. + cred = await manager1.LoadCredentials(ct).WaitAsync(ct); + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + + // SetCredentials should succeed. + await manager1.SetCredentials(TestServerUrl, TestApiToken, ct).WaitAsync(ct); + + // Cached credential should be valid. + cred = manager1.GetCachedCredentials(); + Assert.That(cred.State, Is.EqualTo(CredentialState.Valid)); + Assert.That(cred.CoderUrl, Is.EqualTo(TestServerUrl)); + Assert.That(cred.ApiToken, Is.EqualTo(TestApiToken)); + Assert.That(cred.Username, Is.EqualTo(TestUsername)); + + // Load credentials should return the same reference. + var loadedCred = await manager1.LoadCredentials(ct).WaitAsync(ct); + Assert.That(ReferenceEquals(cred, loadedCred), Is.True); + + // A second manager should be able to load the same credentials. + var manager2 = new CredentialManager(credentialBackend, apiClientFactory.Object); + cred = await manager2.LoadCredentials(ct).WaitAsync(ct); + Assert.That(cred.State, Is.EqualTo(CredentialState.Valid)); + Assert.That(cred.CoderUrl, Is.EqualTo(TestServerUrl)); + Assert.That(cred.ApiToken, Is.EqualTo(TestApiToken)); + Assert.That(cred.Username, Is.EqualTo(TestUsername)); + + // Clearing the credentials should make them invalid. + await manager1.ClearCredentials(ct).WaitAsync(ct); + cred = manager1.GetCachedCredentials(); + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + + // And loading them in a new manager should also be invalid. + var manager3 = new CredentialManager(credentialBackend, apiClientFactory.Object); + cred = await manager3.LoadCredentials(ct).WaitAsync(ct); + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + } + finally + { + // In case something goes wrong, make sure to clean up. + using var cts = new CancellationTokenSource(); + cts.CancelAfter(15_000); + await credentialBackend.DeleteCredentials(cts.Token); + } + } + + [Test(Description = "Test SetCredentials with invalid URL or token")] + [CancelAfter(30_000)] + public void SetCredentialsInvalidUrlOrToken(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + var apiClientFactory = new Mock(MockBehavior.Strict); + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + + var cases = new List<(string, string, string)> + { + (null!, TestApiToken, "Coder URL is required"), + ("", TestApiToken, "Coder URL is required"), + (" ", TestApiToken, "Coder URL is required"), + (new string('a', 129), TestApiToken, "Coder URL is too long"), + ("a", TestApiToken, "not a valid URL"), + ("ftp://dev.coder.com", TestApiToken, "Coder URL must be HTTP or HTTPS"), + + (TestServerUrl, null!, "API token is required"), + (TestServerUrl, "", "API token is required"), + (TestServerUrl, " ", "API token is required"), + }; + + foreach (var (url, token, expectedMessage) in cases) + { + var ex = Assert.ThrowsAsync(() => + manager.SetCredentials(url, token, ct)); + Assert.That(ex.Message, Does.Contain(expectedMessage)); + } + } + + [Test(Description = "Invalid server buildinfo response")] + [CancelAfter(30_000)] + public void InvalidServerBuildInfoResponse(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + var apiClient = new Mock(MockBehavior.Strict); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny()).Result) + .Throws(new Exception("Test exception")); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object); + + // Attempt a set. + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + var ex = Assert.ThrowsAsync(() => + manager.SetCredentials(TestServerUrl, TestApiToken, ct)); + Assert.That(ex.Message, Does.Contain("Could not connect to or verify Coder server")); + + // Attempt a load. + credentialBackend.Setup(x => x.ReadCredentials(It.IsAny()).Result) + .Returns(new RawCredentials + { + CoderUrl = TestServerUrl, + ApiToken = TestApiToken, + }); + var cred = manager.LoadCredentials(ct).Result; + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + } + + [Test(Description = "Invalid server version")] + [CancelAfter(30_000)] + public void InvalidServerVersion(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + var apiClient = new Mock(MockBehavior.Strict); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny()).Result) + .Returns(new BuildInfo { Version = "v2.19.0" }); + apiClient.Setup(x => x.SetSessionToken(TestApiToken)); + apiClient.Setup(x => x.GetUser(User.Me, It.IsAny()).Result) + .Returns(new User { Username = TestUsername }); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object); + + // Attempt a set. + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + var ex = Assert.ThrowsAsync(() => + manager.SetCredentials(TestServerUrl, TestApiToken, ct)); + Assert.That(ex.Message, Does.Contain("not within required server version range")); + + // Attempt a load. + credentialBackend.Setup(x => x.ReadCredentials(It.IsAny()).Result) + .Returns(new RawCredentials + { + CoderUrl = TestServerUrl, + ApiToken = TestApiToken, + }); + var cred = manager.LoadCredentials(ct).Result; + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + } + + [Test(Description = "Invalid server user response")] + [CancelAfter(30_000)] + public void InvalidServerUserResponse(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + var apiClient = new Mock(MockBehavior.Strict); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny()).Result) + .Returns(new BuildInfo { Version = "v2.20.0" }); + apiClient.Setup(x => x.SetSessionToken(TestApiToken)); + apiClient.Setup(x => x.GetUser(User.Me, It.IsAny()).Result) + .Throws(new Exception("Test exception")); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object); + + // Attempt a set. + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + var ex = Assert.ThrowsAsync(() => + manager.SetCredentials(TestServerUrl, TestApiToken, ct)); + Assert.That(ex.Message, Does.Contain("Could not connect to or verify Coder server")); + + // Attempt a load. + credentialBackend.Setup(x => x.ReadCredentials(It.IsAny()).Result) + .Returns(new RawCredentials + { + CoderUrl = TestServerUrl, + ApiToken = TestApiToken, + }); + var cred = manager.LoadCredentials(ct).Result; + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + } + + [Test(Description = "Invalid username")] + [CancelAfter(30_000)] + public void InvalidUsername(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + var apiClient = new Mock(MockBehavior.Strict); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny()).Result) + .Returns(new BuildInfo { Version = "v2.20.0" }); + apiClient.Setup(x => x.SetSessionToken(TestApiToken)); + apiClient.Setup(x => x.GetUser(User.Me, It.IsAny()).Result) + .Returns(new User { Username = "" }); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object); + + // Attempt a set. + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + var ex = Assert.ThrowsAsync(() => + manager.SetCredentials(TestServerUrl, TestApiToken, ct)); + Assert.That(ex.Message, Does.Contain("username is empty")); + + // Attempt a load. + credentialBackend.Setup(x => x.ReadCredentials(It.IsAny()).Result) + .Returns(new RawCredentials + { + CoderUrl = TestServerUrl, + ApiToken = TestApiToken, + }); + var cred = manager.LoadCredentials(ct).Result; + Assert.That(cred.State, Is.EqualTo(CredentialState.Invalid)); + } + + [Test(Description = "Duplicate loads should use the same Task")] + [CancelAfter(30_000)] + public async Task DuplicateLoads(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + credentialBackend.Setup(x => x.ReadCredentials(It.IsAny()).Result) + .Returns(new RawCredentials + { + CoderUrl = TestServerUrl, + ApiToken = TestApiToken, + }) + .Verifiable(Times.Exactly(1)); + var apiClient = new Mock(MockBehavior.Strict); + // To accomplish delay, the GetBuildInfo will wait for a TCS. + var tcs = new TaskCompletionSource(); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny())) + .Returns(async (CancellationToken _) => + { + await tcs.Task.WaitAsync(ct); + return new BuildInfo { Version = "v2.20.0" }; + }) + .Verifiable(Times.Exactly(1)); + apiClient.Setup(x => x.SetSessionToken(TestApiToken)); + apiClient.Setup(x => x.GetUser(User.Me, It.IsAny()).Result) + .Returns(new User { Username = TestUsername }) + .Verifiable(Times.Exactly(1)); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object) + .Verifiable(Times.Exactly(1)); + + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + var cred1Task = manager.LoadCredentials(ct); + var cred2Task = manager.LoadCredentials(ct); + Assert.That(ReferenceEquals(cred1Task, cred2Task), Is.True); + tcs.SetResult(); + var cred1 = await cred1Task.WaitAsync(ct); + var cred2 = await cred2Task.WaitAsync(ct); + Assert.That(ReferenceEquals(cred1, cred2), Is.True); + + credentialBackend.Verify(); + apiClient.Verify(); + apiClientFactory.Verify(); + } + + [Test(Description = "A set during a load should cancel the load")] + [CancelAfter(30_000)] + public async Task SetDuringLoad(CancellationToken ct) + { + var credentialBackend = new Mock(MockBehavior.Strict); + // To accomplish a delay on the load, ReadCredentials will block on the CT. + credentialBackend.Setup(x => x.ReadCredentials(It.IsAny())) + .Returns(async (CancellationToken innerCt) => + { + await Task.Delay(Timeout.Infinite, innerCt).WaitAsync(ct); + throw new UnreachableException(); + }); + credentialBackend.Setup(x => + x.WriteCredentials( + It.Is(c => c.CoderUrl == TestServerUrl && c.ApiToken == TestApiToken), + It.IsAny())) + .Returns(Task.CompletedTask); + var apiClient = new Mock(MockBehavior.Strict); + apiClient.Setup(x => x.GetBuildInfo(It.IsAny()).Result) + .Returns(new BuildInfo { Version = "v2.20.0" }); + apiClient.Setup(x => x.SetSessionToken(TestApiToken)); + apiClient.Setup(x => x.GetUser(User.Me, It.IsAny()).Result) + .Returns(new User { Username = TestUsername }); + var apiClientFactory = new Mock(MockBehavior.Strict); + apiClientFactory.Setup(x => x.Create(TestServerUrl)) + .Returns(apiClient.Object); + + var manager = new CredentialManager(credentialBackend.Object, apiClientFactory.Object); + // Start a load... + var loadTask = manager.LoadCredentials(ct); + // Then fully perform a set. + await manager.SetCredentials(TestServerUrl, TestApiToken, ct).WaitAsync(ct); + // The load should have been cancelled. + Assert.ThrowsAsync(() => loadTask); + } +} diff --git a/Tests.App/Tests.App.csproj b/Tests.App/Tests.App.csproj new file mode 100644 index 0000000..cc01512 --- /dev/null +++ b/Tests.App/Tests.App.csproj @@ -0,0 +1,38 @@ + + + + Coder.Desktop.Tests.App + Coder.Desktop.Tests.App + net8.0-windows10.0.19041.0 + preview + enable + enable + true + + false + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + diff --git a/Tests.Vpn.Service/packages.lock.json b/Tests.Vpn.Service/packages.lock.json index 9becace..45e0457 100644 --- a/Tests.Vpn.Service/packages.lock.json +++ b/Tests.Vpn.Service/packages.lock.json @@ -474,6 +474,7 @@ "type": "Project", "dependencies": { "Coder.Desktop.Vpn.Proto": "[1.0.0, )", + "Semver": "[3.0.0, )", "System.IO.Pipelines": "[9.0.1, )" } }, diff --git a/Tests.Vpn/Utilities/ServerVersionUtilitiesTest.cs b/Tests.Vpn/Utilities/ServerVersionUtilitiesTest.cs new file mode 100644 index 0000000..ac96013 --- /dev/null +++ b/Tests.Vpn/Utilities/ServerVersionUtilitiesTest.cs @@ -0,0 +1,74 @@ +using Coder.Desktop.Vpn.Utilities; +using Semver; + +namespace Coder.Desktop.Tests.Vpn.Utilities; + +[TestFixture] +public class ServerVersionUtilitiesTest +{ + [Test(Description = "Test invalid versions")] + public void InvalidVersions() + { + var invalidVersions = new List<(string, string)> + { + (null!, "Server version is empty"), + ("", "Server version is empty"), + (" ", "Server version is empty"), + ("v", "Could not parse server version"), + ("1", "Could not parse server version"), + ("v1", "Could not parse server version"), + ("1.2", "Could not parse server version"), + ("v1.2", "Could not parse server version"), + ("1.2.3.4", "Could not parse server version"), + ("v1.2.3.4", "Could not parse server version"), + + ("1.2.3", "not within required server version range"), + ("v1.2.3", "not within required server version range"), + ("2.19.0-devel", "not within required server version range"), + ("v2.19.0-devel", "not within required server version range"), + }; + + foreach (var (version, expectedErrorMessage) in invalidVersions) + { + var ex = Assert.Throws(() => + ServerVersionUtilities.ParseAndValidateServerVersion(version)); + Assert.That(ex.Message, Does.Contain(expectedErrorMessage)); + } + } + + [Test(Description = "Test valid versions")] + public void ValidVersions() + { + var validVersions = new List + { + new() + { + RawString = "2.20.0-devel+17f8e93d0", + SemVersion = new SemVersion(2, 20, 0, ["devel"], ["17f8e93d0"]), + }, + new() + { + RawString = "2.20.0", + SemVersion = new SemVersion(2, 20, 0), + }, + new() + { + RawString = "2.21.3", + SemVersion = new SemVersion(2, 21, 3), + }, + new() + { + RawString = "3.0.0", + SemVersion = new SemVersion(3, 0, 0), + }, + }; + + foreach (var version in validVersions) + foreach (var prefix in new[] { "", "v" }) + { + var result = ServerVersionUtilities.ParseAndValidateServerVersion(prefix + version.RawString); + Assert.That(result.RawString, Is.EqualTo(prefix + version.RawString), version.RawString); + Assert.That(result.SemVersion, Is.EqualTo(version.SemVersion), version.RawString); + } + } +} diff --git a/Tests.Vpn/packages.lock.json b/Tests.Vpn/packages.lock.json index 1ba2a4f..10f6f62 100644 --- a/Tests.Vpn/packages.lock.json +++ b/Tests.Vpn/packages.lock.json @@ -46,6 +46,11 @@ "resolved": "17.12.0", "contentHash": "4svMznBd5JM21JIG2xZKGNanAHNXplxf/kQDFfLHXQ3OnpJkayRK/TjacFjA+EYmoyuNXHo/sOETEfcYtAzIrA==" }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "5.0.1", + "contentHash": "5WPSmL4YeP7eW+Vc8XZ4DwjYWBAiSwDV9Hm63JJWcz1Ie3Xjv4KuJXzgCstj48LkLfVCYa7mLcx7y+q6yqVvtw==" + }, "Microsoft.TestPlatform.ObjectModel": { "type": "Transitive", "resolved": "17.12.0", @@ -68,6 +73,14 @@ "resolved": "13.0.1", "contentHash": "ppPFpBcvxdsfUonNcvITKqLl3bqxWbDCZIzDWHzjpdAHRFfZe0Dw9HmA0+za13IdyrgJwpkDTDA9fHaxOrt20A==" }, + "Semver": { + "type": "Transitive", + "resolved": "3.0.0", + "contentHash": "9jZCicsVgTebqkAujRWtC9J1A5EQVlu0TVKHcgoCuv345ve5DYf4D1MjhKEnQjdRZo6x/vdv6QQrYFs7ilGzLA==", + "dependencies": { + "Microsoft.Extensions.Primitives": "5.0.1" + } + }, "System.IO.Pipelines": { "type": "Transitive", "resolved": "9.0.1", @@ -82,6 +95,7 @@ "type": "Project", "dependencies": { "Coder.Desktop.Vpn.Proto": "[1.0.0, )", + "Semver": "[3.0.0, )", "System.IO.Pipelines": "[9.0.1, )" } }, diff --git a/Vpn.DebugClient/packages.lock.json b/Vpn.DebugClient/packages.lock.json index 93925e9..403a41b 100644 --- a/Vpn.DebugClient/packages.lock.json +++ b/Vpn.DebugClient/packages.lock.json @@ -7,6 +7,19 @@ "resolved": "3.29.3", "contentHash": "t7nZFFUFwigCwZ+nIXHDLweXvwIpsOXi+P7J7smPT/QjI3EKxnCzTQOhBqyEh6XEzc/pNH+bCFOOSjatrPt6Tw==" }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "5.0.1", + "contentHash": "5WPSmL4YeP7eW+Vc8XZ4DwjYWBAiSwDV9Hm63JJWcz1Ie3Xjv4KuJXzgCstj48LkLfVCYa7mLcx7y+q6yqVvtw==" + }, + "Semver": { + "type": "Transitive", + "resolved": "3.0.0", + "contentHash": "9jZCicsVgTebqkAujRWtC9J1A5EQVlu0TVKHcgoCuv345ve5DYf4D1MjhKEnQjdRZo6x/vdv6QQrYFs7ilGzLA==", + "dependencies": { + "Microsoft.Extensions.Primitives": "5.0.1" + } + }, "System.IO.Pipelines": { "type": "Transitive", "resolved": "9.0.1", @@ -16,6 +29,7 @@ "type": "Project", "dependencies": { "Coder.Desktop.Vpn.Proto": "[1.0.0, )", + "Semver": "[3.0.0, )", "System.IO.Pipelines": "[9.0.1, )" } }, diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index 9e70b31..3345e98 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -1,7 +1,7 @@ using System.Runtime.InteropServices; +using Coder.Desktop.CoderSdk; using Coder.Desktop.Vpn.Proto; using Coder.Desktop.Vpn.Utilities; -using CoderSdk; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Semver; @@ -16,12 +16,6 @@ public enum TunnelStatus Stopped, } -public class ServerVersion -{ - public required string String { get; set; } - public required SemVersion SemVersion { get; set; } -} - public interface IManager : IDisposable { public Task StopAsync(CancellationToken ct = default); @@ -32,9 +26,6 @@ public interface IManager : IDisposable /// public class Manager : IManager { - // TODO: determine a suitable value for this - private static readonly SemVersionRange ServerVersionRange = SemVersionRange.All; - private readonly ManagerConfig _config; private readonly IDownloader _downloader; private readonly ILogger _logger; @@ -141,7 +132,7 @@ private async ValueTask HandleClientMessageStart(ClientMessage me var serverVersion = await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct); if (_status == TunnelStatus.Started && _lastStartRequest != null && - _lastStartRequest.Equals(message.Start) && _lastServerVersion?.String == serverVersion.String) + _lastStartRequest.Equals(message.Start) && _lastServerVersion?.RawString == serverVersion.RawString) { // The client is requesting to start an identical tunnel while // we're already running it. @@ -373,20 +364,11 @@ private async ValueTask CheckServerVersionAndCredentials(string b var buildInfo = await client.GetBuildInfo(ct); _logger.LogInformation("Fetched server version '{ServerVersion}'", buildInfo.Version); - if (buildInfo.Version.StartsWith('v')) buildInfo.Version = buildInfo.Version[1..]; - var serverVersion = SemVersion.Parse(buildInfo.Version); - if (!serverVersion.Satisfies(ServerVersionRange)) - throw new InvalidOperationException( - $"Server version '{serverVersion}' is not within required server version range '{ServerVersionRange}'"); - + var serverVersion = ServerVersionUtilities.ParseAndValidateServerVersion(buildInfo.Version); var user = await client.GetUser(User.Me, ct); _logger.LogInformation("Authenticated to server as '{Username}'", user.Username); - return new ServerVersion - { - String = buildInfo.Version, - SemVersion = serverVersion, - }; + return serverVersion; } /// diff --git a/Vpn.Service/packages.lock.json b/Vpn.Service/packages.lock.json index b2fba99..ace2cdb 100644 --- a/Vpn.Service/packages.lock.json +++ b/Vpn.Service/packages.lock.json @@ -416,6 +416,7 @@ "type": "Project", "dependencies": { "Coder.Desktop.Vpn.Proto": "[1.0.0, )", + "Semver": "[3.0.0, )", "System.IO.Pipelines": "[9.0.1, )" } }, diff --git a/Vpn/Utilities/RaiiSemaphoreSlim.cs b/Vpn/Utilities/RaiiSemaphoreSlim.cs index e38db6a..25f12bc 100644 --- a/Vpn/Utilities/RaiiSemaphoreSlim.cs +++ b/Vpn/Utilities/RaiiSemaphoreSlim.cs @@ -30,13 +30,13 @@ public IDisposable Lock() return new Locker(_semaphore); } - public async ValueTask LockAsync(CancellationToken ct = default) + public async Task LockAsync(CancellationToken ct = default) { await _semaphore.WaitAsync(ct); return new Locker(_semaphore); } - public async ValueTask LockAsync(TimeSpan timeout, CancellationToken ct = default) + public async Task LockAsync(TimeSpan timeout, CancellationToken ct = default) { if (!await _semaphore.WaitAsync(timeout, ct)) return null; return new Locker(_semaphore); @@ -44,16 +44,16 @@ public async ValueTask LockAsync(CancellationToken ct = default) private class Locker : IDisposable { - private readonly SemaphoreSlim _semaphore1; + private readonly SemaphoreSlim _semaphore; public Locker(SemaphoreSlim semaphore) { - _semaphore1 = semaphore; + _semaphore = semaphore; } public void Dispose() { - _semaphore1.Release(); + _semaphore.Release(); GC.SuppressFinalize(this); } } diff --git a/Vpn/Utilities/ServerVersionUtilities.cs b/Vpn/Utilities/ServerVersionUtilities.cs new file mode 100644 index 0000000..88bca69 --- /dev/null +++ b/Vpn/Utilities/ServerVersionUtilities.cs @@ -0,0 +1,45 @@ +using Semver; + +namespace Coder.Desktop.Vpn.Utilities; + +public class ServerVersion +{ + public required string RawString { get; set; } + public required SemVersion SemVersion { get; set; } +} + +public static class ServerVersionUtilities +{ + // The -0 allows pre-release versions. + private static readonly SemVersionRange ServerVersionRange = SemVersionRange.Parse(">= 2.20.0-0", + SemVersionRangeOptions.IncludeAllPrerelease | SemVersionRangeOptions.AllowV | + SemVersionRangeOptions.AllowMetadata); + + /// + /// Attempts to parse and verify that the server version is within the supported range. + /// + /// + /// The server version to check, optionally with a leading `v` or extra metadata/pre-release + /// tags + /// + /// The parsed server version + /// Could not parse version + /// The server version is not in range + public static ServerVersion ParseAndValidateServerVersion(string versionString) + { + if (string.IsNullOrWhiteSpace(versionString)) + throw new ArgumentException("Server version is empty", nameof(versionString)); + if (!SemVersion.TryParse(versionString, SemVersionStyles.AllowV, out var version)) + throw new ArgumentException($"Could not parse server version '{versionString}'", nameof(versionString)); + if (!version.Satisfies(ServerVersionRange)) + throw new ArgumentException( + $"Server version '{version}' is not within required server version range '{ServerVersionRange}'", + nameof(versionString)); + + return new ServerVersion + { + RawString = versionString, + SemVersion = version, + }; + } +} diff --git a/Vpn/Vpn.csproj b/Vpn/Vpn.csproj index e8016f3..c08b669 100644 --- a/Vpn/Vpn.csproj +++ b/Vpn/Vpn.csproj @@ -14,6 +14,7 @@ + diff --git a/Vpn/packages.lock.json b/Vpn/packages.lock.json index c62e288..5eca812 100644 --- a/Vpn/packages.lock.json +++ b/Vpn/packages.lock.json @@ -2,6 +2,15 @@ "version": 1, "dependencies": { "net8.0": { + "Semver": { + "type": "Direct", + "requested": "[3.0.0, )", + "resolved": "3.0.0", + "contentHash": "9jZCicsVgTebqkAujRWtC9J1A5EQVlu0TVKHcgoCuv345ve5DYf4D1MjhKEnQjdRZo6x/vdv6QQrYFs7ilGzLA==", + "dependencies": { + "Microsoft.Extensions.Primitives": "5.0.1" + } + }, "System.IO.Pipelines": { "type": "Direct", "requested": "[9.0.1, )", @@ -13,6 +22,11 @@ "resolved": "3.29.3", "contentHash": "t7nZFFUFwigCwZ+nIXHDLweXvwIpsOXi+P7J7smPT/QjI3EKxnCzTQOhBqyEh6XEzc/pNH+bCFOOSjatrPt6Tw==" }, + "Microsoft.Extensions.Primitives": { + "type": "Transitive", + "resolved": "5.0.1", + "contentHash": "5WPSmL4YeP7eW+Vc8XZ4DwjYWBAiSwDV9Hm63JJWcz1Ie3Xjv4KuJXzgCstj48LkLfVCYa7mLcx7y+q6yqVvtw==" + }, "Coder.Desktop.Vpn.Proto": { "type": "Project", "dependencies": {