WIP: Add an integration test for Artifact API #769

Open · wants to merge 1 commit into base: main

186 changes: 186 additions & 0 deletions tests/artifacts_test.go
@@ -0,0 +1,186 @@
//go:build test_integration
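// NOTE: the build tag above keeps this file out of plain `go test` runs;
// it compiles only when -tags test_integration is supplied.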

/*
Copyright 2024.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package integration

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"testing"
"time"

TestUtil "github.com/opendatahub-io/data-science-pipelines-operator/tests/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func (suite *IntegrationTestSuite) TestFetchArtifacts() {

suite.T().Run("Should successfully fetch and download artifacts", func(t *testing.T) {

// Start port-forwarding
cmd := exec.CommandContext(context.Background(),
"kubectl", "port-forward", "-n", suite.DSPANamespace, "svc/artifact-service", fmt.Sprintf("%d:8080", PortForwardLocalPort))
err := cmd.Start()
require.NoError(t, err, "Failed to start port-forwarding")

// Ensure the port-forwarding process is terminated after the test
defer func() {
_ = cmd.Process.Kill()
_ = cmd.Wait() // Wait for the process to terminate completely
}()

// Wait briefly to ensure port-forwarding is established
time.Sleep(5 * time.Second)
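// NOTE: a fixed sleep is a simple heuristic; polling the forwarded port
// until it accepts connections would be more deterministic (a sketch
// follows this file).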

type ResponseArtifact struct {
ArtifactID string `json:"artifact_id"`
DownloadUrl string `json:"download_url"`
}
type ResponseArtifactData struct {
Artifacts []ResponseArtifact `json:"artifacts"`
}

name := "Test Iris Pipeline"
uploadUrl := fmt.Sprintf("%s/apis/v2beta1/pipelines/upload?name=%s", APIServerURL, url.QueryEscape(name))
vals := map[string]string{
"uploadfile": "@resources/iris_pipeline_without_cache_compiled.yaml",
}
bodyUpload, contentTypeUpload := TestUtil.FormFromFile(t, vals)
response, err := suite.Clientmgr.httpClient.Post(uploadUrl, contentTypeUpload, bodyUpload)
require.NoError(t, err)
responseData, err := io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)

// Retrieve Pipeline ID to create a new run
pipelineID, err := TestUtil.RetrievePipelineId(t, suite.Clientmgr.httpClient, APIServerURL, name)
require.NoError(t, err)

// Create a new run
runUrl := fmt.Sprintf("%s/apis/v2beta1/runs", APIServerURL)
bodyRun := TestUtil.FormatRequestBody(t, pipelineID, name)
contentTypeRun := "application/json"
response, err = suite.Clientmgr.httpClient.Post(runUrl, contentTypeRun, bytes.NewReader(bodyRun))
require.NoError(t, err)
responseData, err = io.ReadAll(response.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, response.StatusCode)
err = TestUtil.WaitForPipelineRunCompletion(t, suite.Clientmgr.httpClient, APIServerURL)
require.NoError(t, err)

// fetch artifacts
artifactsUrl := fmt.Sprintf("%s/apis/v2beta1/artifacts?namespace=%s", APIServerURL, suite.DSPANamespace)
response, err = suite.Clientmgr.httpClient.Get(artifactsUrl)
require.NoError(t, err)
responseData, err = io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)

// iterate over the artifacts
var responseArtifactsData ResponseArtifactData
err = json.Unmarshal(responseData, &responseArtifactsData)
if err != nil {
t.Errorf("Error unmarshaling JSON: %v", err)
return
}
hasDownloadError := false
for _, artifact := range responseArtifactsData.Artifacts {
// get the artifact by ID
artifactsByIdUrl := fmt.Sprintf("%s/apis/v2beta1/artifacts/%s", APIServerURL, artifact.ArtifactID)
response, err = suite.Clientmgr.httpClient.Get(artifactsByIdUrl)
require.NoError(t, err)
responseData, err = io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)

// get the download URL: view=DOWNLOAD asks the API server to include a
// download_url for the artifact's object-store content in the response
artifactsByIdUrl = fmt.Sprintf("%s/apis/v2beta1/artifacts/%s?view=DOWNLOAD", APIServerURL, artifact.ArtifactID)
response, err = suite.Clientmgr.httpClient.Get(artifactsByIdUrl)
require.NoError(t, err)
responseData, err = io.ReadAll(response.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)
loggr.Info(string(responseData))

var responseArtifactData ResponseArtifact
err = json.Unmarshal(responseData, &responseArtifactData)
if err != nil {
t.Errorf("Error unmarshaling JSON: %v", err)
return
}

content, err := downloadFile(responseArtifactData.DownloadUrl, "/tmp/download", suite.Clientmgr.httpClient)

require.NoError(t, err)
// There was an issue in the past where the returned URL yielded an Access Denied response
if strings.Contains(content, "Access Denied") {
hasDownloadError = true
loggr.Error(errors.New("error downloading the artifact"), content)
}
}
if hasDownloadError {
t.Errorf("Error downloading the artifacts. Double check the error messages in the log")
}
})
}

// downloadFile fetches the given URL, writes the body to destPath, and
// returns the content as a string. The parameters avoid shadowing the
// imported net/url package.
func downloadFile(fileURL, destPath string, httpClient http.Client) (string, error) {
// Create an HTTP GET request to fetch the file from the URL
response, err := httpClient.Get(fileURL)
if err != nil {
return "", fmt.Errorf("failed to fetch the file: %w", err)
}
defer response.Body.Close()

// Check if the response status is OK (200)
if response.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to download file: status code %d", response.StatusCode)
}

// Read the content from the response body
content, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("failed to read content: %w", err)
}

// Create the file
file, err := os.Create(destPath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()

// Write the content to the file
_, err = file.Write(content)
if err != nil {
return "", fmt.Errorf("failed to write content to file: %w", err)
}

return string(content), nil
}
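A note on invocation: because of the test_integration build tag, a plain `go test` skips this file. A run would look roughly like `go test -tags test_integration ./tests/...`; any additional harness flags for kubeconfig, namespace, or API server URL are repo-specific and not shown here.

On the fixed 5-second sleep after starting the port-forward: a more deterministic alternative is to poll the local port until it accepts TCP connections. A minimal sketch, assuming the net package is added to the imports; the helper name and timeouts are illustrative, not part of this PR:

// waitForPort dials the forwarded local port until it accepts a
// connection or the deadline passes.
func waitForPort(port int, timeout time.Duration) error {
	addr := fmt.Sprintf("localhost:%d", port)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", addr, 500*time.Millisecond)
		if err == nil {
			conn.Close()
			return nil
		}
		time.Sleep(200 * time.Millisecond)
	}
	return fmt.Errorf("port %d not ready within %s", port, timeout)
}

The test body could then replace the time.Sleep call with require.NoError(t, waitForPort(PortForwardLocalPort, 30*time.Second)).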
160 changes: 160 additions & 0 deletions tests/resources/iris_pipeline_without_cache.py
@@ -0,0 +1,160 @@
from kfp import compiler, dsl
from kfp.dsl import ClassificationMetrics, Dataset, Input, Model, Output

common_base_image = (
"registry.redhat.io/ubi8/python-39@sha256:3523b184212e1f2243e76d8094ab52b01ea3015471471290d011625e1763af61"
)
# common_base_image = "quay.io/opendatahub/ds-pipelines-sample-base:v1.0"


@dsl.component(base_image=common_base_image, packages_to_install=["pandas==2.2.0"])
def create_dataset(iris_dataset: Output[Dataset]):
from io import StringIO # noqa: PLC0415

import pandas as pd # noqa: PLC0415

data = """
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
"""
col_names = ["Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Labels"]
df = pd.read_csv(StringIO(data), names=col_names)

with open(iris_dataset.path, "w") as f:
df.to_csv(f)


@dsl.component(
base_image=common_base_image,
packages_to_install=["pandas==2.2.0", "scikit-learn==1.4.0"],
)
def normalize_dataset(
input_iris_dataset: Input[Dataset],
normalized_iris_dataset: Output[Dataset],
standard_scaler: bool,
):
import pandas as pd # noqa: PLC0415
from sklearn.preprocessing import MinMaxScaler, StandardScaler # noqa: PLC0415

with open(input_iris_dataset.path) as f:
df = pd.read_csv(f)
labels = df.pop("Labels")

scaler = StandardScaler() if standard_scaler else MinMaxScaler()

df = pd.DataFrame(scaler.fit_transform(df))
df["Labels"] = labels
normalized_iris_dataset.metadata["state"] = "Normalized"
with open(normalized_iris_dataset.path, "w") as f:
df.to_csv(f)


@dsl.component(
base_image=common_base_image,
packages_to_install=["pandas==2.2.0", "scikit-learn==1.4.0"],
)
def train_model(
normalized_iris_dataset: Input[Dataset],
model: Output[Model],
metrics: Output[ClassificationMetrics],
n_neighbors: int,
):
import pickle # noqa: PLC0415

import pandas as pd # noqa: PLC0415
from sklearn.metrics import confusion_matrix # noqa: PLC0415
from sklearn.model_selection import cross_val_predict, train_test_split # noqa: PLC0415
from sklearn.neighbors import KNeighborsClassifier # noqa: PLC0415

with open(normalized_iris_dataset.path) as f:
df = pd.read_csv(f)

y = df.pop("Labels")
X = df

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) # noqa: F841

clf = KNeighborsClassifier(n_neighbors=n_neighbors)
clf.fit(X_train, y_train)

predictions = cross_val_predict(clf, X_train, y_train, cv=3)
metrics.log_confusion_matrix(
["Iris-Setosa", "Iris-Versicolour", "Iris-Virginica"],
confusion_matrix(y_train, predictions).tolist(), # .tolist() to convert np array to list.
)

model.metadata["framework"] = "scikit-learn"
with open(model.path, "wb") as f:
pickle.dump(clf, f)


@dsl.pipeline(name="iris-training-pipeline")
def my_pipeline(
standard_scaler: bool = True,
neighbors: int = 3,
):
create_dataset_task = create_dataset().set_caching_options(False)

normalize_dataset_task = normalize_dataset(
input_iris_dataset=create_dataset_task.outputs["iris_dataset"], standard_scaler=standard_scaler
).set_caching_options(False)

train_model(
normalized_iris_dataset=normalize_dataset_task.outputs["normalized_iris_dataset"], n_neighbors=neighbors
).set_caching_options(False)


if __name__ == "__main__":
compiler.Compiler().compile(my_pipeline, package_path=__file__.replace(".py", "_compiled.yaml"))
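Running this module directly compiles the pipeline to iris_pipeline_without_cache_compiled.yaml next to the source file; that is the resources/iris_pipeline_without_cache_compiled.yaml the Go test above uploads. Regenerating it after a pipeline change is a single command, e.g. python iris_pipeline_without_cache.py.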
