diff --git a/internal/signalio/csv.go b/internal/signalio/csv.go index 4a5b3dd05..53bc347c5 100644 --- a/internal/signalio/csv.go +++ b/internal/signalio/csv.go @@ -33,56 +33,83 @@ type csvWriter struct { mu sync.Mutex } -func CSVWriter(w io.Writer, emptySets []signal.Set, extra ...string) Writer { +func CSVWriter(writer io.Writer, emptySets []signal.Set, extra ...string) Writer { return &csvWriter{ header: fieldsFromSignalSets(emptySets, extra), - w: csv.NewWriter(w), + w: csv.NewWriter(writer), } } // WriteSignals implements the Writer interface. -func (w *csvWriter) WriteSignals(signals []signal.Set, extra ...Field) error { +func (writer *csvWriter) WriteSignals(signals []signal.Set, extra ...Field) error { values, err := marshalToMap(signals, extra...) if err != nil { return err } - return w.writeRecord(values) + return writer.writeRecord(values) } -func (w *csvWriter) maybeWriteHeader() error { - // Check headerWritten without the lock to avoid holding the lock if the - // header has already been written. - if w.headerWritten { +func (writer *csvWriter) maybeWriteHeader() error { + /* + The variable writer.headerWritten is checked twice to avoid what is known as a "race condition". + A race condition can occur when two or more goroutines try to access a shared resource + (in this case, the csvWriter instance) concurrently, and the outcome of the program depends on + the interleaving of their execution. + + Imagine the following scenario: + + 1. Goroutine A reads the value of writer.headerWritten as false. + 2. Goroutine B reads the value of writer.headerWritten as false. + 3. Goroutine A acquires the mutex lock and sets writer.headerWritten to true. + 4. Goroutine B acquires the mutex lock and sets writer.headerWritten to true. + + If this happens, the header will be written twice, which is not the desired behavior. + By checking writer.headerWritten twice, once before acquiring the mutex lock and once after acquiring the lock, + the function can ensure that only one goroutine enters the critical section and writes the header. + + Here's how the function works: + + 1. Goroutine A reads the value of writer.headerWritten as false. + 2. Goroutine A acquires the mutex lock. + 3. Goroutine A re-checks the value of writer.headerWritten and finds it to be false. + 4. Goroutine A sets writer.headerWritten to true and writes the header. + 5. Goroutine A releases the mutex lock. + + If Goroutine B tries to enter the critical section at any point after step 2, + it will have to wait until Goroutine A releases the lock in step 5. Once the lock is released, + Goroutine B will re-check the value of writer.headerWritten and find it to be true, + so it will not write the header again. + */ + + if writer.headerWritten { return nil } - // Grab the lock and re-check headerWritten just in case another goroutine - // entered the same critical section. Also prevent concurrent writes to w. - w.mu.Lock() - defer w.mu.Unlock() - if w.headerWritten { + writer.mu.Lock() + defer writer.mu.Unlock() + if writer.headerWritten { return nil } - w.headerWritten = true - return w.w.Write(w.header) + writer.headerWritten = true + return writer.w.Write(writer.header) } -func (w *csvWriter) writeRecord(values map[string]string) error { - if err := w.maybeWriteHeader(); err != nil { +func (writer *csvWriter) writeRecord(values map[string]string) error { + if err := writer.maybeWriteHeader(); err != nil { return err } var rec []string - for _, k := range w.header { + for _, k := range writer.header { rec = append(rec, values[k]) } // Grab the lock when we're ready to write the record to prevent - // concurrent writes to w. - w.mu.Lock() - defer w.mu.Unlock() - if err := w.w.Write(rec); err != nil { + // concurrent writes to writer. + writer.mu.Lock() + defer writer.mu.Unlock() + if err := writer.w.Write(rec); err != nil { return err } - w.w.Flush() - return w.w.Error() + writer.w.Flush() + return writer.w.Error() } func marshalValue(value any) (string, error) { diff --git a/internal/signalio/csv_test.go b/internal/signalio/csv_test.go new file mode 100644 index 000000000..5cd0fb66f --- /dev/null +++ b/internal/signalio/csv_test.go @@ -0,0 +1,219 @@ +package signalio + +import ( + "encoding/csv" + "sync" + "testing" + "time" + + "github.com/ossf/criticality_score/internal/collector/signal" +) + +func TestMarshalValue(t *testing.T) { + tests := []struct { + value any + expected string + wantErr bool + }{ + {value: true, expected: "true", wantErr: false}, + {value: 1, expected: "1", wantErr: false}, + {value: int16(2), expected: "2", wantErr: false}, + {value: int32(3), expected: "3", wantErr: false}, + {value: int64(4), expected: "4", wantErr: false}, + {value: uint(5), expected: "5", wantErr: false}, + {value: uint16(6), expected: "6", wantErr: false}, + {value: uint32(7), expected: "7", wantErr: false}, + {value: uint64(8), expected: "8", wantErr: false}, + {value: byte(9), expected: "9", wantErr: false}, + {value: float32(10.1), expected: "10.1", wantErr: false}, + {value: 11.1, expected: "11.1", wantErr: false}, // float64 + {value: "test", expected: "test", wantErr: false}, + {value: time.Now(), expected: time.Now().Format(time.RFC3339), wantErr: false}, + {value: nil, expected: "", wantErr: false}, + {value: []int{1, 2, 3}, expected: "", wantErr: true}, + {value: map[string]string{"key": "value"}, expected: "", wantErr: true}, + {value: struct{}{}, expected: "", wantErr: true}, + } + for _, test := range tests { + res, err := marshalValue(test.value) + if (err != nil) != test.wantErr { + t.Errorf("Unexpected error for value %v: wantErr %v, got %v", test.value, test.wantErr, err) + } + if res != test.expected { + t.Errorf("Unexpected result for value %v: expected %v, got %v", test.value, test.expected, res) + } + } +} + +func Test_csvWriter_maybeWriteHeader(t *testing.T) { + type fields struct { + w *csv.Writer + header []string + headerWritten bool + } + tests := []struct { + name string + fields fields + }{ + { + name: "write header", + fields: fields{ + w: csv.NewWriter(nil), + header: []string{}, + headerWritten: false, + }, + }, + { + name: "header already written", + fields: fields{ + w: csv.NewWriter(nil), + header: []string{"a", "b", "c"}, + headerWritten: true, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w := &csvWriter{ + w: test.fields.w, + header: test.fields.header, + headerWritten: test.fields.headerWritten, + mu: sync.Mutex{}, + } + if err := w.maybeWriteHeader(); err != nil { // never want an error with these test cases + t.Errorf("maybeWriteHeader() error = %v", err) + } + }) + } +} + +func Test_csvWriter_writeRecord(t *testing.T) { + type fields struct { + w *csv.Writer + header []string + headerWritten bool + } + tests := []struct { //nolint:govet + name string + fields fields + values map[string]string + wantErr bool + }{ + { + name: "write record with regular error", + fields: fields{ + w: csv.NewWriter(&mockWriter{ + written: []byte{'a', 'b', 'c'}, + error: nil, + }), + header: []string{"a", "b", "c"}, + headerWritten: true, + }, + wantErr: true, + }, + { + name: "write record with write error", + fields: fields{ + w: &csv.Writer{}, + header: []string{"a", "b", "c"}, + headerWritten: true, + }, + wantErr: true, + }, + { + name: "write record with maybeWriteHeader error", + fields: fields{ + w: &csv.Writer{}, + header: []string{"a", "b", "c"}, + headerWritten: false, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &csvWriter{ + w: tt.fields.w, + header: tt.fields.header, + headerWritten: tt.fields.headerWritten, + mu: sync.Mutex{}, + } + if err := w.writeRecord(tt.values); (err != nil) != tt.wantErr { + t.Errorf("writeRecord() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +type mockWriter struct { //nolint:govet + written []byte + error error +} + +func (m *mockWriter) Write(p []byte) (n int, err error) { + return 0, m.error +} + +func Test_csvWriter_WriteSignals(t *testing.T) { + type args struct { + signals []signal.Set + extra []Field + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "write signals with marshal error", + args: args{ + signals: []signal.Set{ + &testSet{ + UpdatedCount: signal.Val(1), + }, + }, + extra: []Field{ + { + Key: "a", + Value: []int{1, 2, 3}, + }, + { + Key: "b", + Value: map[string]string{"key": "value"}, + }, + }, + }, + wantErr: true, + }, + { + name: "write signals with write error", + args: args{ + extra: []Field{ + { + Key: "a", + Value: "1", + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writer := CSVWriter(&mockWriter{}, []signal.Set{}, "a", "b") + + if err := writer.WriteSignals(tt.args.signals, tt.args.extra...); (err != nil) != tt.wantErr { + t.Errorf("WriteSignals() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +type testSet struct { //nolint:govet + UpdatedCount signal.Field[int] + Field string +} + +func (t testSet) Namespace() signal.Namespace { + return "test" +}