From 04bafe2cea64cbf814baa690a1ee3dee8ab86fcd Mon Sep 17 00:00:00 2001 From: voidint Date: Tue, 28 Nov 2023 20:38:41 +0800 Subject: [PATCH] misc: add ut --- collector/aliyun/aliyun_collector_test.go | 52 +++++++++++++ collector/official/official_collector_test.go | 73 +++++++++++++++++++ 2 files changed, 125 insertions(+) diff --git a/collector/aliyun/aliyun_collector_test.go b/collector/aliyun/aliyun_collector_test.go index b2c9b41..55968ab 100644 --- a/collector/aliyun/aliyun_collector_test.go +++ b/collector/aliyun/aliyun_collector_test.go @@ -2,11 +2,17 @@ package aliyun import ( "bytes" + "errors" + "fmt" + "net/http" + "net/http/httptest" "os" "testing" "github.com/PuerkitoBio/goquery" + "github.com/agiledragon/gomonkey/v2" "github.com/stretchr/testify/assert" + "github.com/voidint/g/pkg/errs" "github.com/voidint/g/version" ) @@ -80,3 +86,49 @@ func TestCollector_ArchivedVersions(t *testing.T) { assert.Equal(t, []*version.Version{}, vs) }) } + +func TestNewCollector(t *testing.T) { + rr1 := httptest.NewRecorder() + rr1.WriteHeader(http.StatusNotFound) + + rr2 := httptest.NewRecorder() + rr2.WriteHeader(http.StatusOK) + htmlData, err := os.ReadFile("./testdata/golang_dl.html") + assert.Nil(t, err) + rr2.Write(htmlData) + + patches := gomonkey.ApplyMethodSeq(&http.Client{}, "Get", []gomonkey.OutputCell{ + {Values: gomonkey.Params{nil, errors.New("unknown error")}}, + {Values: gomonkey.Params{rr1.Result(), nil}}, + {Values: gomonkey.Params{rr2.Result(), nil}}, + }) + defer patches.Reset() + + tests := []struct { + name string + wantErr error + }{ + { + name: "默认站点URL访问异常", + wantErr: errs.NewURLUnreachableError(DownloadPageURL, errors.New("unknown error")), + }, + { + name: "默认站点URL资源不存在", + wantErr: errs.NewURLUnreachableError(DownloadPageURL, fmt.Errorf("%d", http.StatusNotFound)), + }, + { + name: "默认站点URL访问采集正常", + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewCollector() + assert.Equal(t, tt.wantErr, err) + if err == nil { + assert.NotNil(t, got.pURL) + assert.NotNil(t, got.doc) + } + }) + } +} diff --git a/collector/official/official_collector_test.go b/collector/official/official_collector_test.go index 3a4bc27..0ac766a 100644 --- a/collector/official/official_collector_test.go +++ b/collector/official/official_collector_test.go @@ -2,12 +2,20 @@ package official import ( "bytes" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" "os" + "strings" "testing" "github.com/PuerkitoBio/goquery" + "github.com/agiledragon/gomonkey/v2" "github.com/stretchr/testify/assert" "github.com/voidint/g/pkg/checksum" + "github.com/voidint/g/pkg/errs" "github.com/voidint/g/version" ) @@ -241,3 +249,68 @@ func TestAllVersions(t *testing.T) { assert.Equal(t, 15, len(items[len(items)-1].Packages())) }) } + +func TestNewCollector(t *testing.T) { + t.Run("无效URL", func(t *testing.T) { + var invalidURL strings.Builder + invalidURL.WriteByte(0x7f) + invalidURL.WriteString("hello world") + + c, err := NewCollector(invalidURL.String()) + assert.Nil(t, c) + assert.NotNil(t, err) + e, ok := err.(*url.Error) + assert.True(t, ok) + assert.Equal(t, "parse", e.Op) + assert.Equal(t, invalidURL.String(), e.URL) + assert.NotNil(t, e.Err) + }) + + rr1 := httptest.NewRecorder() + rr1.WriteHeader(http.StatusNotFound) + + rr2 := httptest.NewRecorder() + rr2.WriteHeader(http.StatusOK) + htmlData, err := os.ReadFile("./testdata/golang_dl_with_rc.html") + assert.Nil(t, err) + rr2.Write(htmlData) + + patches := gomonkey.ApplyMethodSeq(&http.Client{}, "Get", []gomonkey.OutputCell{ + {Values: gomonkey.Params{nil, errors.New("unknown error")}}, + {Values: gomonkey.Params{rr1.Result(), nil}}, + {Values: gomonkey.Params{rr2.Result(), nil}}, + }) + defer patches.Reset() + + tests := []struct { + name string + url string + wantErr error + }{ + { + name: "默认站点URL访问异常", + url: "", + wantErr: errs.NewURLUnreachableError(DefaultDownloadPageURL, errors.New("unknown error")), + }, + { + name: "默认站点URL资源不存在", + url: "", + wantErr: errs.NewURLUnreachableError(DefaultDownloadPageURL, fmt.Errorf("%d", http.StatusNotFound)), + }, + { + name: "默认站点URL访问采集正常", + url: "", + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewCollector(tt.url) + assert.Equal(t, tt.wantErr, err) + if err == nil { + assert.NotNil(t, got.pURL) + assert.NotNil(t, got.doc) + } + }) + } +}