Skip to content

Commit

Permalink
bulk-cdk: StreamIdentifier (#45680)
Browse files Browse the repository at this point in the history
  • Loading branch information
postamar authored Sep 19, 2024
1 parent 0a24c13 commit 5daf229
Show file tree
Hide file tree
Showing 27 changed files with 224 additions and 208 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk

import io.airbyte.protocol.models.v0.AirbyteStream
import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair
import io.airbyte.protocol.models.v0.StreamDescriptor

/**
* [StreamIdentifier] is equivalent to [AirbyteStreamNameNamespacePair].
*
* This exists to avoid coupling the Bulk CDK code too closely to the Airbyte Protocol objects.
*/
data class StreamIdentifier
private constructor(
val namespace: String?,
val name: String,
) {
companion object {
fun from(desc: StreamDescriptor): StreamIdentifier =
StreamIdentifier(desc.namespace, desc.name)

fun from(stream: AirbyteStream): StreamIdentifier =
StreamIdentifier(stream.namespace, stream.name)
}

override fun toString(): String {
return if (namespace == null) name else "${namespace}.${name}"
}
}

fun StreamIdentifier.asProtocolStreamDescriptor(): StreamDescriptor =
StreamDescriptor().withName(name).withNamespace(namespace)

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
package io.airbyte.cdk.command

import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.StreamNamePair
import io.airbyte.cdk.StreamIdentifier

/** Union type of the state passed as input to a READ for a source connector. */
sealed interface InputState
Expand All @@ -11,13 +11,13 @@ data object EmptyInputState : InputState

data class GlobalInputState(
val global: OpaqueStateValue,
val globalStreams: Map<StreamNamePair, OpaqueStateValue>,
val globalStreams: Map<StreamIdentifier, OpaqueStateValue>,
/** Conceivably, some streams may undergo a full refresh alongside independently of the rest. */
val nonGlobalStreams: Map<StreamNamePair, OpaqueStateValue>,
val nonGlobalStreams: Map<StreamIdentifier, OpaqueStateValue>,
) : InputState

data class StreamInputState(
val streams: Map<StreamNamePair, OpaqueStateValue>,
val streams: Map<StreamIdentifier, OpaqueStateValue>,
) : InputState

/** State values are opaque for the CDK, the schema is owned by the connector. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ package io.airbyte.cdk.command

import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.ConfigErrorException
import io.airbyte.cdk.StreamNamePair
import io.airbyte.cdk.asStreamNamePair
import io.airbyte.cdk.StreamIdentifier
import io.airbyte.cdk.util.Jsons
import io.airbyte.cdk.util.ResourceUtils
import io.airbyte.protocol.models.v0.AirbyteGlobalState
Expand Down Expand Up @@ -49,7 +48,7 @@ class InputStateFactory {
if (msg.stream == null) {
msg.type.toString()
} else {
msg.stream.streamDescriptor.asStreamNamePair().toString()
StreamIdentifier.from(msg.stream.streamDescriptor).toString()
}
}
.mapNotNull { (groupKey, groupValues) ->
Expand All @@ -61,7 +60,7 @@ class InputStateFactory {
}
groupValues.last()
}
val nonGlobalStreams: Map<StreamNamePair, OpaqueStateValue> =
val nonGlobalStreams: Map<StreamIdentifier, OpaqueStateValue> =
streamStates(deduped.mapNotNull { it.stream })
val globalState: AirbyteGlobalState? =
deduped.find { it.type == AirbyteStateMessage.AirbyteStateType.GLOBAL }?.global
Expand All @@ -73,18 +72,18 @@ class InputStateFactory {
globalState.sharedState,
OpaqueStateValue::class.java,
)
val globalStreams: Map<StreamNamePair, OpaqueStateValue> =
val globalStreams: Map<StreamIdentifier, OpaqueStateValue> =
streamStates(globalState.streamStates)
return GlobalInputState(globalStateValue, globalStreams, nonGlobalStreams)
}

private fun streamStates(
streamStates: List<AirbyteStreamState>?,
): Map<StreamNamePair, OpaqueStateValue> =
): Map<StreamIdentifier, OpaqueStateValue> =
(streamStates ?: listOf()).associate { msg: AirbyteStreamState ->
val key: StreamNamePair = msg.streamDescriptor.asStreamNamePair()
val streamID: StreamIdentifier = StreamIdentifier.from(msg.streamDescriptor)
val jsonValue: JsonNode = msg.streamState ?: Jsons.objectNode()
key to ValidatedJsonUtils.parseUnvalidated(jsonValue, OpaqueStateValue::class.java)
streamID to ValidatedJsonUtils.parseUnvalidated(jsonValue, OpaqueStateValue::class.java)
}

private fun validateStateMessage(message: AirbyteStateMessage) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
/* Copyright (c) 2024 Airbyte, Inc., all rights reserved. */
package io.airbyte.cdk.command

import io.airbyte.cdk.StreamNamePair
import io.airbyte.cdk.StreamIdentifier
import io.airbyte.cdk.util.Jsons
import io.airbyte.protocol.models.v0.StreamDescriptor
import io.micronaut.context.annotation.Property
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Inject
Expand All @@ -27,9 +28,9 @@ class InputStateTest {
val expected =
StreamInputState(
mapOf(
StreamNamePair("bar", "foo") to
streamID("foo", "bar") to
Jsons.readTree("{\"primary_key\":{\"k1\":10,\"k2\":20}}"),
StreamNamePair("baz", "foo") to Jsons.readTree("{\"cursors\":{\"c\":30}}"),
streamID("foo", "baz") to Jsons.readTree("{\"cursors\":{\"c\":30}}"),
),
)
Assertions.assertEquals(
Expand All @@ -49,7 +50,7 @@ class InputStateTest {
global = Jsons.readTree("{\"cdc\":{}}"),
globalStreams =
mapOf(
StreamNamePair("bar", "foo") to
streamID("foo", "bar") to
Jsons.readTree("{\"primary_key\":{\"k1\":10,\"k2\":20}}"),
),
nonGlobalStreams = mapOf(),
Expand All @@ -71,13 +72,12 @@ class InputStateTest {
global = Jsons.readTree("{\"cdc\":{}}"),
globalStreams =
mapOf(
StreamNamePair("bar", "foo") to
streamID("foo", "bar") to
Jsons.readTree("{\"primary_key\":{\"k1\":10,\"k2\":20}}"),
),
nonGlobalStreams =
mapOf(
StreamNamePair("baz", "foo") to
Jsons.readTree("{\"primary_key\":{\"k\":1}}"),
streamID("foo", "baz") to Jsons.readTree("{\"primary_key\":{\"k\":1}}"),
),
)
Assertions.assertEquals(
Expand All @@ -94,18 +94,20 @@ class InputStateTest {
global = Jsons.readTree("{\"cdc\":{}}"),
globalStreams =
mapOf(
StreamNamePair("bar", "foo") to
streamID("foo", "bar") to
Jsons.readTree("{\"primary_key\":{\"k1\":10,\"k2\":20}}"),
),
nonGlobalStreams =
mapOf(
StreamNamePair("baz", "foo") to
Jsons.readTree("{\"primary_key\":{\"k\":10}}"),
streamID("foo", "baz") to Jsons.readTree("{\"primary_key\":{\"k\":10}}"),
),
)
Assertions.assertEquals(
Jsons.writeValueAsString(expected),
Jsons.writeValueAsString(actual),
)
}

fun streamID(namespace: String, name: String): StreamIdentifier =
StreamIdentifier.from(StreamDescriptor().withName(name).withNamespace(namespace))
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,19 @@ class CheckOperation<T : ConfigurationJsonObjectBase>(
var n = 0
val namespaces: List<String?> = listOf<String?>(null) + metadataQuerier.streamNamespaces()
for (namespace in namespaces) {
for (name in metadataQuerier.streamNames(namespace)) {
for (streamID in metadataQuerier.streamNames(namespace)) {
try {
metadataQuerier.fields(name, namespace)
metadataQuerier.fields(streamID)
} catch (e: Exception) {
log.info(e) {
"Query failed on stream '$name' in '${namespace ?: ""}': ${e.message}"
"Query failed on stream '${streamID.name}' in '${namespace ?: ""}': ${e.message}"
}
n++
continue
}
log.info { "Query successful on stream '$name' in '${namespace ?: ""}'." }
log.info {
"Query successful on stream '${streamID.name}' in '${namespace ?: ""}'."
}
return
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ interface AirbyteStreamFactory {

fun createAirbyteStream(discoveredStream: DiscoveredStream): AirbyteStream =
CatalogHelpers.createAirbyteStream(
discoveredStream.name,
discoveredStream.namespace,
discoveredStream.id.name,
discoveredStream.id.namespace,
discoveredStream.columns.map {
AirbyteField.of(it.id, it.type.airbyteType.asJsonSchemaType())
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ class DiscoverOperation(
val namespaces: List<String?> =
listOf<String?>(null) + metadataQuerier.streamNamespaces()
for (namespace in namespaces) {
for (name in metadataQuerier.streamNames(namespace)) {
val fields: List<Field> = metadataQuerier.fields(name, namespace)
for (streamID in metadataQuerier.streamNames(namespace)) {
val fields: List<Field> = metadataQuerier.fields(streamID)
if (fields.isEmpty()) {
log.info {
"Ignoring stream '$name' in '${namespace ?: ""}' because no fields were discovered."
"Ignoring stream '${streamID.name}' in '${namespace ?: ""}' because no fields were discovered."
}
continue
}
val primaryKey: List<List<String>> = metadataQuerier.primaryKey(name, namespace)
val discoveredStream = DiscoveredStream(name, namespace, fields, primaryKey)
val primaryKey: List<List<String>> = metadataQuerier.primaryKey(streamID)
val discoveredStream = DiscoveredStream(streamID, fields, primaryKey)
val airbyteStream: AirbyteStream =
if (config.global) {
airbyteStreamFactory.createGlobal(discoveredStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

package io.airbyte.cdk.discover

import io.airbyte.cdk.StreamIdentifier

data class DiscoveredStream(
val name: String,
val namespace: String?,
val id: StreamIdentifier,
val columns: List<Field>,
val primaryKeyColumnIDs: List<List<String>>,
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* Copyright (c) 2024 Airbyte, Inc., all rights reserved. */
package io.airbyte.cdk.discover

import io.airbyte.cdk.StreamIdentifier
import io.airbyte.cdk.command.SourceConfiguration

/** An abstraction for a catalog discovery session. */
Expand All @@ -10,19 +11,13 @@ interface MetadataQuerier : AutoCloseable {
fun streamNamespaces(): List<String>

/** Returns all available stream names in the given namespace. */
fun streamNames(streamNamespace: String?): List<String>
fun streamNames(streamNamespace: String?): List<StreamIdentifier>

/** Returns all available fields in the given stream. */
fun fields(
streamName: String,
streamNamespace: String?,
): List<Field>
fun fields(streamID: StreamIdentifier): List<Field>

/** Returns the primary key for the given stream, if it exists; empty list otherwise. */
fun primaryKey(
streamName: String,
streamNamespace: String?,
): List<List<String>>
fun primaryKey(streamID: StreamIdentifier): List<List<String>>

/** Executes extra checks which throw a [io.airbyte.cdk.ConfigErrorException] on failure. */
fun extraChecks()
Expand Down
Loading

0 comments on commit 5daf229

Please sign in to comment.