Skip to content

Commit 4f22e24

Browse files
committed
support specifying secret per client
1 parent 26bd421 commit 4f22e24

File tree

4 files changed

+152
-5
lines changed

4 files changed

+152
-5
lines changed

README.md

+11-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,16 @@ func (p radiusService) RadiusHandle(request *radius.Packet) *radius.Packet {
5252
}
5353

5454
func main() {
55-
s := radius.NewServer(":1812", "sEcReT", radiusService{})
55+
s := radius.NewServer(":1812", "secret", radiusService{})
56+
57+
// or you can convert it to a server that accept request
58+
// from some host with different secret
59+
// cls := radius.NewClientList([]radius.Client{
60+
// radius.NewClient("127.0.0.1", "secret1"),
61+
// radius.NewClient("10.10.10.10", "secret2"),
62+
// })
63+
// s.WithClientList(cls)
64+
5665
signalChan := make(chan os.Signal, 1)
5766
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
5867
errChan := make(chan error)
@@ -92,4 +101,4 @@ func main() {
92101
### TODO
93102
* avpEapMessaget.Value error handle.
94103
* implement eap-MSCHAPV2 server side.
95-
* implement radius client side.
104+
* implement radius client side.

client.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package radius
2+
3+
import "sync"
4+
5+
func NewClientList(cs []Client) *ClientList {
6+
cl := new(ClientList)
7+
cl.SetHerd(cs)
8+
return cl
9+
}
10+
11+
// ClientList are list of client allowed to communicate with server
12+
type ClientList struct {
13+
herd map[string]Client
14+
sync.RWMutex
15+
}
16+
17+
// Get client from list of clients based on host
18+
func (cls *ClientList) Get(host string) Client {
19+
cls.RLock()
20+
defer cls.RUnlock()
21+
cl, _ := cls.herd[host]
22+
return cl
23+
}
24+
25+
// Add new client or reset existing client based on host
26+
func (cls *ClientList) AddOrUpdate(cl Client) {
27+
cls.Lock()
28+
defer cls.Unlock()
29+
cls.herd[cl.GetHost()] = cl
30+
}
31+
32+
// Remove client based on host
33+
func (cls *ClientList) Remove(host string) {
34+
cls.Lock()
35+
defer cls.Unlock()
36+
delete(cls.herd, host)
37+
}
38+
39+
// SetHerd reset/initialize the herd of clients
40+
func (cls *ClientList) SetHerd(herd []Client) {
41+
cls.Lock()
42+
defer cls.Unlock()
43+
if cls.herd == nil {
44+
cls.herd = make(map[string]Client)
45+
}
46+
for _, v := range herd {
47+
cls.herd[v.GetHost()] = v
48+
}
49+
}
50+
51+
func (cls *ClientList) GetHerd() []Client {
52+
cls.RLock()
53+
defer cls.RUnlock()
54+
herd := make([]Client, len(cls.herd))
55+
i := 0
56+
for _, v := range cls.herd {
57+
herd[i] = v
58+
i++
59+
}
60+
return herd
61+
}
62+
63+
// Client represent a client to connect to radius server
64+
type Client interface {
65+
// GetHost get the client host
66+
GetHost() string
67+
// GetSecret get shared secret
68+
GetSecret() string
69+
}
70+
71+
// NewClient return new client
72+
func NewClient(host, secret string) Client {
73+
return &DefaultClient{host, secret}
74+
}
75+
76+
// DefaultClient is default client implementation
77+
type DefaultClient struct {
78+
Host string
79+
Secret string
80+
}
81+
82+
// GetSecret get shared secret
83+
func (cl *DefaultClient) GetSecret() string {
84+
return cl.Secret
85+
}
86+
87+
// GetHost get the client host
88+
func (cl *DefaultClient) GetHost() string {
89+
return cl.Host
90+
}

client_test.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package radius
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
func TestClientList(t *testing.T) {
9+
herd := []Client{
10+
NewClient("1.1.1.1", "secret1"),
11+
NewClient("2.2.2.2", "secret2"),
12+
}
13+
cls := NewClientList(herd)
14+
15+
ok(reflect.DeepEqual(cls.GetHerd(), herd))
16+
17+
newClient := NewClient("3.3.3.3", "secret3")
18+
cls.AddOrUpdate(newClient)
19+
ok(reflect.DeepEqual(cls.Get("3.3.3.3"), newClient))
20+
ok(len(cls.GetHerd()) == 3)
21+
22+
updateClient := NewClient("1.1.1.1", "updatesecret")
23+
cls.AddOrUpdate(updateClient)
24+
ok(reflect.DeepEqual(cls.Get("1.1.1.1"), updateClient))
25+
ok(len(cls.GetHerd()) == 3)
26+
27+
cls.Remove("3.3.3.3")
28+
println(cls.Get("3.3.3.3"))
29+
ok(cls.Get("3.3.3.3") == nil)
30+
ok(len(cls.GetHerd()) == 2)
31+
}

server.go

+20-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type Server struct {
1616
service Service
1717
ch chan struct{}
1818
waitGroup *sync.WaitGroup
19-
//services map[string]Service
19+
cl *ClientList
2020
}
2121

2222
type Service interface {
@@ -43,6 +43,11 @@ func NewServer(addr string, secret string, service Service) *Server {
4343
return s
4444
}
4545

46+
// WithClientList set a list of clients that have it's own secret
47+
func (s *Server) WithClientList(cl *ClientList) {
48+
s.cl = cl
49+
}
50+
4651
/*
4752
func (s *Server) RegisterService(serviceAddr string, handler Service) {
4853
s.services[serviceAddr] = handler
@@ -80,8 +85,20 @@ func (s *Server) ListenAndServe() error {
8085
s.waitGroup.Add(1)
8186
go func(p []byte, addr net.Addr) {
8287
defer s.waitGroup.Done()
83-
//fmt.Printf("DecodePacket %#v\n",p)
84-
pac, err := DecodePacket(s.secret, p)
88+
var secret = s.secret
89+
90+
if s.cl != nil {
91+
host, _, err := net.SplitHostPort(addr.String())
92+
if err != nil {
93+
fmt.Println("[pac.Host]", err)
94+
return
95+
}
96+
if cl := s.cl.Get(host); cl != nil {
97+
secret = cl.GetSecret()
98+
}
99+
}
100+
101+
pac, err := DecodePacket(secret, p)
85102
if err != nil {
86103
fmt.Println("[pac.Decode]", err)
87104
return

0 commit comments

Comments
 (0)