-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsocket_test.go
170 lines (157 loc) · 3.49 KB
/
socket_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package requests
import (
"context"
"errors"
"net"
"strings"
"testing"
"time"
)
func TestSocket(t *testing.T) {
// 启动测试服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
defer listener.Close()
// 在后台接受连接
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
defer conn.Close()
// 简单的回显服务
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
conn.Write(buf[:n])
}()
tests := []struct {
name string
opts []Option
wantErr bool
}{
{
name: "TCP正常连接",
opts: []Option{
URL("tcp://" + listener.Addr().String()),
Timeout(time.Second),
},
wantErr: false,
},
{
name: "无效URL",
opts: []Option{
URL("invalid://localhost"),
Timeout(time.Second),
},
wantErr: true,
},
{
name: "错误URL",
opts: []Option{
URL("://:::"),
Timeout(time.Second),
},
wantErr: true,
},
{
name: "连接超时",
opts: []Option{
URL("tcp://240.0.0.1:12345"), // 不可达的地址
Timeout(1),
},
wantErr: true,
},
{
name: "Unix socket连接",
opts: []Option{
URL("unix:///tmp/test.sock"),
Timeout(time.Second),
},
wantErr: true, // Unix socket文件不存在,应该失败
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
conn, err := Socket(ctx, tt.opts...)
if (err != nil) != tt.wantErr {
t.Errorf("Socket() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
defer conn.Close()
// 测试连接是否可用
_, err = conn.Write([]byte("test"))
if err != nil {
t.Errorf("写入数据失败: %v", err)
}
buf := make([]byte, 4)
_, err = conn.Read(buf)
if err != nil {
t.Errorf("读取数据失败: %v", err)
}
if string(buf) != "test" {
t.Errorf("期望读取到 'test',得到 %s", string(buf))
}
}
})
}
}
func TestSocket_ContextCancel(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
defer listener.Close()
// 在后台接受连接
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
defer conn.Close()
// 简单的回显服务
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
conn.Write(buf[:n])
}()
// 创建一个可取消的上下文
ctx, cancel := context.WithCancel(context.Background())
// 立即取消
cancel()
// 尝试建立连接
if _, err = Socket(ctx, URL("tcp://"+listener.Addr().String())); errors.Is(err, context.Canceled) {
t.Log(err)
return
}
t.Errorf("期望错误为 context.Canceled,得到 %v", err)
}
func TestSocket_WithCustomDialer(t *testing.T) {
// 测试自定义本地地址
localAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // 系统自动分配端口
}
// 启动测试服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("创建监听器失败: %v", err)
}
defer listener.Close()
// 使用自定义本地地址建立连接
conn, err := Socket(context.Background(),
URL("tcp://"+listener.Addr().String()),
LocalAddr(localAddr),
)
if err != nil {
t.Fatalf("建立连接失败: %v", err)
}
defer conn.Close()
// 验证连接的本地地址
localAddrStr := conn.LocalAddr().String()
if !strings.Contains(localAddrStr, "127.0.0.1") {
t.Errorf("期望本地地址为 127.0.0.1,得到 %s", localAddrStr)
}
}