Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AAD fallback if key authentication is disabled #2290

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 76 additions & 41 deletions src/tree/SubscriptionTreeItem.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import { CosmosDBManagementClient } from '@azure/arm-cosmosdb';
import { DatabaseAccountGetResults, DatabaseAccountListKeysResult } from '@azure/arm-cosmosdb/src/models';
import { ILocationWizardContext, LocationListStep, ResourceGroupListStep, SubscriptionTreeItemBase, getResourceGroupFromId, uiUtils } from '@microsoft/vscode-azext-azureutils';
import { AzExtParentTreeItem, AzExtTreeItem, AzureWizard, AzureWizardPromptStep, IActionContext } from '@microsoft/vscode-azext-utils';
import { AzExtParentTreeItem, AzExtTreeItem, AzureWizard, AzureWizardPromptStep, IActionContext, callWithTelemetryAndErrorHandling } from '@microsoft/vscode-azext-utils';
import * as vscode from 'vscode';
import { API, Experience, getExperienceLabel, tryGetExperience } from '../AzureDBExperiences';
import { CosmosDBCredential } from '../docdb/getCosmosClient';
import { CosmosDBCredential, CosmosDBKeyCredential } from '../docdb/getCosmosClient';
import { DocDBAccountTreeItem } from "../docdb/tree/DocDBAccountTreeItem";
import { ext } from '../extensionVariables';
import { tryGetGremlinEndpointFromAzure } from '../graph/gremlinEndpoints';
Expand Down Expand Up @@ -121,51 +121,86 @@ export class SubscriptionTreeItem extends SubscriptionTreeItemBase {
const label: string = name + (accountKindLabel ? ` (${accountKindLabel})` : ``);
const isEmulator: boolean = false;

if (experience && experience.api === "MongoDB") {
const result = await client.databaseAccounts.listConnectionStrings(resourceGroup, name);
const connectionString: URL = new URL(nonNullProp(nonNullProp(result, 'connectionStrings')[0], 'connectionString'));
// for any Mongo connectionString, append this query param because the Cosmos Mongo API v3.6 doesn't support retrywrites
// but the newer node.js drivers started breaking this
const searchParam: string = 'retrywrites';
if (!connectionString.searchParams.has(searchParam)) {
connectionString.searchParams.set(searchParam, 'false');
}

// Use the default connection string
return new MongoAccountTreeItem(parent, id, label, connectionString.toString(), isEmulator, databaseAccount);
} else {
let keyResult: DatabaseAccountListKeysResult | undefined;
try {
keyResult = await client.databaseAccounts.listKeys(resourceGroup, name);
} catch (error) {
// If the client failed to list keys, proceed without using keys
}
const newNode = await callWithTelemetryAndErrorHandling('cosmosDB.initCosmosDBChild', async (context: IActionContext) => {
// leave error handling to the caller (command or tree node)
context.errorHandling.suppressDisplay = true;
// rethrow all errors to satisfy initCosmosDBChild contract
context.errorHandling.rethrow = true;
context.telemetry.properties.experience = experience?.api;

if (experience && experience.api === "MongoDB") {
const result = await client.databaseAccounts.listConnectionStrings(resourceGroup, name);
const connectionString: URL = new URL(nonNullProp(nonNullProp(result, 'connectionStrings')[0], 'connectionString'));
// for any Mongo connectionString, append this query param because the Cosmos Mongo API v3.6 doesn't support retrywrites
// but the newer node.js drivers started breaking this
const searchParam: string = 'retrywrites';
if (!connectionString.searchParams.has(searchParam)) {
connectionString.searchParams.set(searchParam, 'false');
}

let keyCred = keyResult?.primaryMasterKey ? {
type: "key",
key: keyResult.primaryMasterKey
} : undefined;
const testCosmosAuth = vscode.workspace.getConfiguration().get<boolean>("azureDatabases.useCosmosOAuth");
if (testCosmosAuth) {
keyCred = undefined;
}
const authCred = { type: "auth" };
const credentials = [keyCred, authCred].filter((cred): cred is CosmosDBCredential => cred !== undefined);
switch (experience && experience.api) {
case "Table":
return new TableAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);
case "Graph": {
const gremlinEndpoint = await tryGetGremlinEndpointFromAzure(client, resourceGroup, name);
return new GraphAccountTreeItem(parent, id, label, documentEndpoint, gremlinEndpoint, credentials, isEmulator, databaseAccount);
// Use the default connection string
return new MongoAccountTreeItem(parent, id, label, connectionString.toString(), isEmulator, databaseAccount);
} else {
let keyCred: CosmosDBKeyCredential | undefined = undefined;

const forceOAuth = vscode.workspace.getConfiguration().get<boolean>("azureDatabases.useCosmosOAuth");
context.telemetry.properties.useCosmosOAuth = (forceOAuth ?? false).toString();

// disable key auth if the user has opted in to OAuth (AAD/Entra ID)
if (!forceOAuth) {
try {
const acc = await client.databaseAccounts.get(resourceGroup, name);
const localAuthDisabled = acc.disableLocalAuth === true;
context.telemetry.properties.localAuthDisabled = localAuthDisabled.toString();
let keyResult: DatabaseAccountListKeysResult | undefined;
// If the account has local auth disabled, don't even try to use key auth
if (!localAuthDisabled) {
keyResult = await client.databaseAccounts.listKeys(resourceGroup, name);
keyCred = keyResult?.primaryMasterKey ? {
type: "key",
key: keyResult.primaryMasterKey
} : undefined;
context.telemetry.properties.receivedKeyCreds = "true";
} else {
throw new Error("Local auth is disabled");
}
} catch (error) {
context.telemetry.properties.receivedKeyCreds = "false";
const message = localize("keyPermissionErrorMsg", "You do not have the required permissions to list auth keys for [{0}].\nFalling back to using Entra ID.\nYou can change the default authentication in the settings.", name);
const openSettingsItem = localize("openSettings", "Open Settings");
void vscode.window.showWarningMessage(message, ...[openSettingsItem]).then((item) => {
if (item === openSettingsItem) {
void vscode.commands.executeCommand('workbench.action.openSettings', 'azureDatabases.useCosmosOAuth');
}
});
}
}
case "Core":
default:
// Default to DocumentDB, the base type for all Cosmos DB Accounts
return new DocDBAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);

// OAuth is always enabled for Cosmos DB and will be used as a fall back if key auth is unavailable
const authCred = { type: "auth" };
const credentials = [keyCred, authCred].filter((cred): cred is CosmosDBCredential => cred !== undefined);
switch (experience && experience.api) {
case "Table":
return new TableAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);
case "Graph": {
const gremlinEndpoint = await tryGetGremlinEndpointFromAzure(client, resourceGroup, name);
return new GraphAccountTreeItem(parent, id, label, documentEndpoint, gremlinEndpoint, credentials, isEmulator, databaseAccount);
}
case "Core":
default:
// Default to DocumentDB, the base type for all Cosmos DB Accounts
return new DocDBAccountTreeItem(parent, id, label, documentEndpoint, credentials, isEmulator, databaseAccount);

}
}
});
if (!(newNode instanceof AzExtTreeItem)) {
// note: this should never happen, callWithTelemetryAndErrorHandling will rethrow all errors
throw new Error(localize('invalidCosmosDBAccount', 'Invalid Cosmos DB account.'));
}
return newNode;
}

public static async initPostgresChild(server: PostgresAbstractServer, parent: AzExtParentTreeItem): Promise<AzExtTreeItem> {
const connectionString: string = createPostgresConnectionString(nonNullProp(server, 'fullyQualifiedDomainName'));
const parsedCS: ParsedPostgresConnectionString = parsePostgresConnectionString(connectionString);
Expand Down
Loading