diff --git a/config/config.ini b/config/config.ini index f63377db..65de6012 100644 --- a/config/config.ini +++ b/config/config.ini @@ -7,4 +7,4 @@ updateUrl = https://downloads.arduino.cc/ origins = https://local.arduino.cc:8000 #httpProxy = http://your.proxy:port # Proxy server for HTTP requests crashreport = false # enable crashreport logging -autostartMacOS = true # the Arduino Create Agent is able to start automatically after login on macOS (launchd agent) \ No newline at end of file +autostartMacOS = true # the Arduino Create Agent is able to start automatically after login on macOS (launchd agent) diff --git a/conn.go b/conn.go index 8c71c54c..42b573bb 100644 --- a/conn.go +++ b/conn.go @@ -27,6 +27,7 @@ import ( "os" "path/filepath" + "github.com/arduino/arduino-create-agent/tools" "github.com/arduino/arduino-create-agent/upload" "github.com/arduino/arduino-create-agent/utilities" "github.com/gin-gonic/gin" @@ -80,7 +81,7 @@ type Upload struct { var uploadStatusStr = "ProgrammerStatus" -func uploadHandler(pubKey *rsa.PublicKey) func(*gin.Context) { +func uploadHandler(hub *hub, pubKey *rsa.PublicKey, tools *tools.Tools) func(*gin.Context) { return func(c *gin.Context) { data := new(Upload) if err := c.BindJSON(data); err != nil { @@ -162,28 +163,28 @@ func uploadHandler(pubKey *rsa.PublicKey) func(*gin.Context) { go func() { // Resolve commandline - commandline, err := upload.PartiallyResolve(data.Board, filePath, tmpdir, data.Commandline, data.Extra, Tools) + commandline, err := upload.PartiallyResolve(data.Board, filePath, tmpdir, data.Commandline, data.Extra, tools) if err != nil { - send(map[string]string{uploadStatusStr: "Error", "Msg": err.Error()}) + send(hub, map[string]string{uploadStatusStr: "Error", "Msg": err.Error()}) return } - l := PLogger{Verbose: true} + l := PLogger{Verbose: true, hub: hub} // Upload if data.Extra.Network { err = errors.New("network upload is not supported anymore, pease use OTA instead") } else { - send(map[string]string{uploadStatusStr: "Starting", "Cmd": "Serial"}) + send(hub, map[string]string{uploadStatusStr: "Starting", "Cmd": "Serial"}) err = upload.Serial(data.Port, commandline, data.Extra, l) } // Handle result if err != nil { - send(map[string]string{uploadStatusStr: "Error", "Msg": err.Error()}) + send(hub, map[string]string{uploadStatusStr: "Error", "Msg": err.Error()}) return } - send(map[string]string{uploadStatusStr: "Done", "Flash": "Ok"}) + send(hub, map[string]string{uploadStatusStr: "Done", "Flash": "Ok"}) }() c.String(http.StatusAccepted, "") @@ -193,6 +194,7 @@ func uploadHandler(pubKey *rsa.PublicKey) func(*gin.Context) { // PLogger sends the info from the upload to the websocket type PLogger struct { Verbose bool + hub *hub } // Debug only sends messages if verbose is true (always true for now) @@ -206,15 +208,15 @@ func (l PLogger) Debug(args ...interface{}) { func (l PLogger) Info(args ...interface{}) { output := fmt.Sprint(args...) log.Println(output) - send(map[string]string{uploadStatusStr: "Busy", "Msg": output}) + send(l.hub, map[string]string{uploadStatusStr: "Busy", "Msg": output}) } -func send(args map[string]string) { +func send(hub *hub, args map[string]string) { mapB, _ := json.Marshal(args) - h.broadcastSys <- mapB + hub.broadcastSys <- mapB } -func wsHandler() *WsServer { +func wsHandler(hub *hub) *WsServer { server, err := socketio.NewServer(nil) if err != nil { log.Fatal(err) @@ -222,13 +224,13 @@ func wsHandler() *WsServer { server.On("connection", func(so socketio.Socket) { c := &connection{send: make(chan []byte, 256*10), ws: so} - h.register <- c + hub.register <- c so.On("command", func(message string) { - h.broadcast <- []byte(message) + hub.broadcast <- []byte(message) }) so.On("disconnection", func() { - h.unregister <- c + hub.unregister <- c }) go c.writer() }) diff --git a/hub.go b/hub.go index a162dd01..9d15240f 100755 --- a/hub.go +++ b/hub.go @@ -16,6 +16,7 @@ package main import ( + "bytes" "encoding/json" "fmt" "html" @@ -26,8 +27,11 @@ import ( "strconv" "strings" + "github.com/arduino/arduino-create-agent/systray" + "github.com/arduino/arduino-create-agent/tools" "github.com/arduino/arduino-create-agent/upload" log "github.com/sirupsen/logrus" + "go.bug.st/serial" ) type hub struct { @@ -45,14 +49,47 @@ type hub struct { // Unregister requests from connections. unregister chan *connection + + // Serial hub to communicate with serial ports + serialHub *serialhub + + serialPortList *serialPortList + + tools *tools.Tools + + systray *systray.Systray } -var h = hub{ - broadcast: make(chan []byte, 1000), - broadcastSys: make(chan []byte, 1000), - register: make(chan *connection), - unregister: make(chan *connection), - connections: make(map[*connection]bool), +func newHub(tools *tools.Tools, systray *systray.Systray) *hub { + broadcastSys := make(chan []byte, 1000) + + onRegister := func(port *serport) { + broadcastSys <- []byte("{\"Cmd\":\"Open\",\"Desc\":\"Got register/open on port.\",\"Port\":\"" + port.portConf.Name + "\",\"Baud\":" + strconv.Itoa(port.portConf.Baud) + ",\"BufferType\":\"" + port.BufferType + "\"}") + } + onUnregister := func(port *serport) { + broadcastSys <- []byte("{\"Cmd\":\"Close\",\"Desc\":\"Got unregister/close on port.\",\"Port\":\"" + port.portConf.Name + "\",\"Baud\":" + strconv.Itoa(port.portConf.Baud) + "}") + } + serialHub := newSerialHub(onRegister, onUnregister) + + onList := func(data []byte) { + broadcastSys <- data + } + onErr := func(err string) { + broadcastSys <- []byte("{\"Error\":\"" + err + "\"}") + } + serialPortList := newSerialPortList(tools, onList, onErr) + + return &hub{ + broadcast: make(chan []byte, 1000), + broadcastSys: broadcastSys, + register: make(chan *connection), + unregister: make(chan *connection), + connections: make(map[*connection]bool), + serialHub: serialHub, + serialPortList: serialPortList, + tools: tools, + systray: systray, + } } const commands = `{ @@ -73,51 +110,53 @@ const commands = `{ ] }` -func (h *hub) unregisterConnection(c *connection) { - if _, contains := h.connections[c]; !contains { +func (hub *hub) unregisterConnection(c *connection) { + if _, contains := hub.connections[c]; !contains { return } - delete(h.connections, c) + delete(hub.connections, c) close(c.send) } -func (h *hub) sendToRegisteredConnections(data []byte) { - for c := range h.connections { +func (hub *hub) sendToRegisteredConnections(data []byte) { + for c := range hub.connections { select { case c.send <- data: //log.Print("did broadcast to ") //log.Print(c.ws.RemoteAddr()) //c.send <- []byte("hello world") default: - h.unregisterConnection(c) + hub.unregisterConnection(c) } } } -func (h *hub) run() { +func (hub *hub) run() { + go hub.serialPortList.Run() + for { select { - case c := <-h.register: - h.connections[c] = true + case c := <-hub.register: + hub.connections[c] = true // send supported commands c.send <- []byte(fmt.Sprintf(`{"Version" : "%s"} `, version)) c.send <- []byte(html.EscapeString(commands)) c.send <- []byte(fmt.Sprintf(`{"Hostname" : "%s"} `, *hostname)) c.send <- []byte(fmt.Sprintf(`{"OS" : "%s"} `, runtime.GOOS)) - case c := <-h.unregister: - h.unregisterConnection(c) - case m := <-h.broadcast: + case c := <-hub.unregister: + hub.unregisterConnection(c) + case m := <-hub.broadcast: if len(m) > 0 { - checkCmd(m) - h.sendToRegisteredConnections(m) + hub.checkCmd(m) + hub.sendToRegisteredConnections(m) } - case m := <-h.broadcastSys: - h.sendToRegisteredConnections(m) + case m := <-hub.broadcastSys: + hub.sendToRegisteredConnections(m) } } } -func checkCmd(m []byte) { +func (hub *hub) checkCmd(m []byte) { //log.Print("Inside checkCmd") s := string(m[:]) @@ -132,18 +171,18 @@ func checkCmd(m []byte) { args := strings.Split(s, " ") if len(args) < 3 { - go spErr("You did not specify a port and baud rate in your open cmd") + go hub.spErr("You did not specify a port and baud rate in your open cmd") return } if len(args[1]) < 1 { - go spErr("You did not specify a serial port") + go hub.spErr("You did not specify a serial port") return } baudStr := strings.Replace(args[2], "\n", "", -1) baud, err := strconv.Atoi(baudStr) if err != nil { - go spErr("Problem converting baud rate " + args[2]) + go hub.spErr("Problem converting baud rate " + args[2]) return } // pass in buffer type now as string. if user does not @@ -154,30 +193,30 @@ func checkCmd(m []byte) { buftype := strings.Replace(args[3], "\n", "", -1) bufferAlgorithm = buftype } - go spHandlerOpen(args[1], baud, bufferAlgorithm) + go hub.spHandlerOpen(args[1], baud, bufferAlgorithm) } else if strings.HasPrefix(sl, "close") { args := strings.Split(s, " ") if len(args) > 1 { - go spClose(args[1]) + go hub.spClose(args[1]) } else { - go spErr("You did not specify a port to close") + go hub.spErr("You did not specify a port to close") } } else if strings.HasPrefix(sl, "killupload") { // kill the running process (assumes singleton for now) go func() { upload.Kill() - h.broadcastSys <- []byte("{\"uploadStatus\": \"Killed\"}") + hub.broadcastSys <- []byte("{\"uploadStatus\": \"Killed\"}") log.Println("{\"uploadStatus\": \"Killed\"}") }() } else if strings.HasPrefix(sl, "send") { // will catch send and sendnobuf and sendraw - go spWrite(s) + go hub.spWrite(s) } else if strings.HasPrefix(sl, "list") { - go serialPorts.List() + go hub.serialPortList.List() } else if strings.HasPrefix(sl, "downloadtool") { go func() { args := strings.Split(s, " ") @@ -188,7 +227,7 @@ func checkCmd(m []byte) { if len(args) <= 1 { mapD := map[string]string{"DownloadStatus": "Error", "Msg": "Not enough arguments"} mapB, _ := json.Marshal(mapD) - h.broadcastSys <- mapB + hub.broadcastSys <- mapB return } if len(args) > 1 { @@ -208,75 +247,234 @@ func checkCmd(m []byte) { behaviour = args[4] } - err := Tools.Download(pack, tool, toolVersion, behaviour) + reportPendingProgress := func(msg string) { + mapD := map[string]string{"DownloadStatus": "Pending", "Msg": msg} + mapB, _ := json.Marshal(mapD) + hub.broadcastSys <- mapB + } + err := hub.tools.Download(pack, tool, toolVersion, behaviour, reportPendingProgress) if err != nil { mapD := map[string]string{"DownloadStatus": "Error", "Msg": err.Error()} mapB, _ := json.Marshal(mapD) - h.broadcastSys <- mapB + hub.broadcastSys <- mapB } else { mapD := map[string]string{"DownloadStatus": "Success", "Msg": "Map Updated"} mapB, _ := json.Marshal(mapD) - h.broadcastSys <- mapB + hub.broadcastSys <- mapB } }() } else if strings.HasPrefix(sl, "log") { - go logAction(sl) + go hub.logAction(sl) } else if strings.HasPrefix(sl, "restart") { + // potentially, the sysStray dependencies can be removed https://github.com/arduino/arduino-create-agent/issues/1013 log.Println("Received restart from the daemon. Why? Boh") - Systray.Restart() + hub.systray.Restart() } else if strings.HasPrefix(sl, "exit") { - Systray.Quit() + hub.systray.Quit() } else if strings.HasPrefix(sl, "memstats") { - memoryStats() + hub.memoryStats() } else if strings.HasPrefix(sl, "gc") { - garbageCollection() + hub.garbageCollection() } else if strings.HasPrefix(sl, "hostname") { - getHostname() + hub.getHostname() } else if strings.HasPrefix(sl, "version") { - getVersion() + hub.getVersion() } else { - go spErr("Could not understand command.") + go hub.spErr("Could not understand command.") } } -func logAction(sl string) { +type logWriter struct { + onWrite func([]byte) +} + +func (hub *hub) logAction(sl string) { if strings.HasPrefix(sl, "log on") { *logDump = "on" - multiWriter := io.MultiWriter(&loggerWs, os.Stderr) + + logWriter := logWriter{} + logWriter.onWrite = func(p []byte) { + hub.broadcastSys <- p + } + + multiWriter := io.MultiWriter(&logWriter, os.Stderr) log.SetOutput(multiWriter) } else if strings.HasPrefix(sl, "log off") { *logDump = "off" log.SetOutput(os.Stderr) // } else if strings.HasPrefix(sl, "log show") { // TODO: send all the saved log to websocket - //h.broadcastSys <- []byte("{\"BufFlowDebug\" : \"" + *logDump + "\"}") + //hub.broadcastSys <- []byte("{\"BufFlowDebug\" : \"" + *logDump + "\"}") } } -func memoryStats() { +func (u *logWriter) Write(p []byte) (n int, err error) { + u.onWrite(p) + return len(p), nil +} + +func (hub *hub) memoryStats() { var memStats runtime.MemStats runtime.ReadMemStats(&memStats) json, _ := json.Marshal(memStats) log.Printf("memStats:%v\n", string(json)) - h.broadcastSys <- json + hub.broadcastSys <- json } -func getHostname() { - h.broadcastSys <- []byte("{\"Hostname\" : \"" + *hostname + "\"}") +func (hub *hub) getHostname() { + hub.broadcastSys <- []byte("{\"Hostname\" : \"" + *hostname + "\"}") } -func getVersion() { - h.broadcastSys <- []byte("{\"Version\" : \"" + version + "\"}") +func (hub *hub) getVersion() { + hub.broadcastSys <- []byte("{\"Version\" : \"" + version + "\"}") } -func garbageCollection() { +func (hub *hub) garbageCollection() { log.Printf("Starting garbageCollection()\n") - h.broadcastSys <- []byte("{\"gc\":\"starting\"}") - memoryStats() + hub.broadcastSys <- []byte("{\"gc\":\"starting\"}") + hub.memoryStats() debug.SetGCPercent(100) debug.FreeOSMemory() debug.SetGCPercent(-1) log.Printf("Done with garbageCollection()\n") - h.broadcastSys <- []byte("{\"gc\":\"done\"}") - memoryStats() + hub.broadcastSys <- []byte("{\"gc\":\"done\"}") + hub.memoryStats() +} + +func (hub *hub) spHandlerOpen(portname string, baud int, buftype string) { + + log.Print("Inside spHandler") + + var out bytes.Buffer + + out.WriteString("Opening serial port ") + out.WriteString(portname) + out.WriteString(" at ") + out.WriteString(strconv.Itoa(baud)) + out.WriteString(" baud") + log.Print(out.String()) + + conf := &SerialConfig{Name: portname, Baud: baud, RtsOn: true} + + mode := &serial.Mode{ + BaudRate: baud, + } + + sp, err := serial.Open(portname, mode) + log.Print("Just tried to open port") + if err != nil { + //log.Fatal(err) + log.Print("Error opening port " + err.Error()) + //hub.broadcastSys <- []byte("Error opening port. " + err.Error()) + hub.broadcastSys <- []byte("{\"Cmd\":\"OpenFail\",\"Desc\":\"Error opening port. " + err.Error() + "\",\"Port\":\"" + conf.Name + "\",\"Baud\":" + strconv.Itoa(conf.Baud) + "}") + + return + } + log.Print("Opened port successfully") + //p := &serport{send: make(chan []byte, 256), portConf: conf, portIo: sp} + // we can go up to 256,000 lines of gcode in the buffer + p := &serport{ + sendBuffered: make(chan string, 256000), + sendNoBuf: make(chan []byte), + sendRaw: make(chan string), + portConf: conf, + portIo: sp, + portName: portname, + BufferType: buftype, + } + + p.OnMessage = func(msg []byte) { + hub.broadcastSys <- msg + } + p.OnClose = func(port *serport) { + hub.serialPortList.MarkPortAsClosed(p.portName) + hub.serialPortList.List() + } + + var bw Bufferflow + + switch buftype { + case "timed": + bw = NewBufferflowTimed(portname, hub.broadcastSys) + case "timedraw": + bw = NewBufferflowTimedRaw(portname, hub.broadcastSys) + case "default": + bw = NewBufferflowDefault(portname, hub.broadcastSys) + default: + log.Panicf("unknown buffer type: %s", buftype) + } + + bw.Init() + p.bufferwatcher = bw + + hub.serialHub.Register(p) + defer hub.serialHub.Unregister(p) + + hub.serialPortList.MarkPortAsOpened(portname) + hub.serialPortList.List() + + // this is internally buffered thread to not send to serial port if blocked + go p.writerBuffered() + // this is thread to send to serial port regardless of block + go p.writerNoBuf() + // this is thread to send to serial port but with base64 decoding + go p.writerRaw() + + p.reader(buftype) + + hub.serialPortList.List() +} + +func (hub *hub) spClose(portname string) { + if myport, ok := hub.serialHub.FindPortByName(portname); ok { + hub.broadcastSys <- []byte("Closing serial port " + portname) + myport.Close() + } else { + hub.spErr("We could not find the serial port " + portname + " that you were trying to close.") + } +} + +func (hub *hub) spWrite(arg string) { + // we will get a string of comXX asdf asdf asdf + //log.Println("Inside spWrite arg: " + arg) + arg = strings.TrimPrefix(arg, " ") + //log.Println("arg after trim: " + arg) + args := strings.SplitN(arg, " ", 3) + if len(args) != 3 { + errstr := "Could not parse send command: " + arg + //log.Println(errstr) + hub.spErr(errstr) + return + } + bufferingMode := args[0] + portname := strings.Trim(args[1], " ") + data := args[2] + + //log.Println("The port to write to is:" + portname + "---") + //log.Println("The data is:" + data + "---") + + // See if we have this port open + port, ok := hub.serialHub.FindPortByName(portname) + if !ok { + // we couldn't find the port, so send err + hub.spErr("We could not find the serial port " + portname + " that you were trying to write to.") + return + } + + // see if bufferingMode is valid + switch bufferingMode { + case "send", "sendnobuf", "sendraw": + // valid buffering mode, go ahead + default: + hub.spErr("Unsupported send command:" + args[0] + ". Please specify a valid one") + return + } + + // send it to the write channel + port.Write(data, bufferingMode) +} + +func (hub *hub) spErr(err string) { + //log.Println("Sending err back: ", err) + hub.broadcastSys <- []byte("{\"Error\" : \"" + err + "\"}") } diff --git a/info.go b/info.go index 88145c02..2d629671 100644 --- a/info.go +++ b/info.go @@ -19,6 +19,7 @@ import ( "runtime" "strings" + "github.com/arduino/arduino-create-agent/systray" "github.com/gin-gonic/gin" "go.bug.st/serial" ) @@ -40,14 +41,16 @@ func infoHandler(c *gin.Context) { }) } -func pauseHandler(c *gin.Context) { - go func() { - ports, _ := serial.GetPortsList() - for _, element := range ports { - spClose(element) - } - *hibernate = true - Systray.Pause() - }() - c.JSON(200, nil) +func pauseHandler(hub *hub, s *systray.Systray) func(c *gin.Context) { + return func(c *gin.Context) { + go func() { + ports, _ := serial.GetPortsList() + for _, element := range ports { + hub.spClose(element) + } + *hibernate = true + s.Pause() + }() + c.JSON(200, nil) + } } diff --git a/main.go b/main.go index 41f824b1..51a26450 100755 --- a/main.go +++ b/main.go @@ -20,7 +20,6 @@ package main import ( _ "embed" - "encoding/json" "flag" "html/template" "io" @@ -100,22 +99,6 @@ var homeTemplate = template.Must(template.New("home").Parse(homeTemplateHTML)) //go:embed home.html var homeTemplateHTML string -// global clients -var ( - Tools *tools.Tools - Systray systray.Systray - Index *index.Resource -) - -type logWriter struct{} - -func (u *logWriter) Write(p []byte) (n int, err error) { - h.broadcastSys <- p - return len(p), nil -} - -var loggerWs logWriter - func homeHandler(c *gin.Context) { homeTemplate.Execute(c.Writer, c.Request.Host) } @@ -141,12 +124,9 @@ func main() { // Check if certificates made with Agent <=1.2.7 needs to be moved over the new location cert.MigrateCertificatesGeneratedWithOldAgentVersions(config.GetCertificatesDir()) - // Launch main loop in a goroutine - go loop() - // SetupSystray is the main thread configDir := config.GetDefaultConfigDir() - Systray = systray.Systray{ + stray := systray.Systray{ Hibernate: *hibernate, Version: version + "-" + commit, DebugURL: func() string { @@ -156,16 +136,19 @@ func main() { ConfigDir: configDir, } + // Launch main loop in a goroutine + go loop(&stray) + if src, err := os.Executable(); err != nil { panic(err) } else if restartPath := updater.Start(src); restartPath != "" { - Systray.RestartWith(restartPath) + stray.RestartWith(restartPath) } else { - Systray.Start() + stray.Start() } } -func loop() { +func loop(stray *systray.Systray) { if *hibernate { return } @@ -182,11 +165,18 @@ func loop() { os.Exit(0) } - logger := func(msg string) { - mapD := map[string]string{"DownloadStatus": "Pending", "Msg": msg} - mapB, _ := json.Marshal(mapD) - h.broadcastSys <- mapB + // Instantiate Index and Tools + index := index.Init(*indexURL, config.GetDataDir()) + if signatureKey == nil || len(*signatureKey) == 0 { + log.Panicf("signature public key should be set") } + signaturePubKey, err := utilities.ParseRsaPublicKey([]byte(*signatureKey)) + if err != nil { + log.Panicf("cannot parse signature key '%s'. %s", *signatureKey, err) + } + tools := tools.New(config.GetDataDir(), index, signaturePubKey) + + hub := newHub(tools, stray) // Let's handle the config configDir := config.GetDefaultConfigDir() @@ -258,7 +248,7 @@ func loop() { if err != nil { log.Panicf("cannot parse arguments: %s", err) } - Systray.SetCurrentConfigFile(configPath) + stray.SetCurrentConfigFile(configPath) // Parse additional ini config if defined if len(*additionalConfig) > 0 { @@ -278,18 +268,6 @@ func loop() { } } - if signatureKey == nil || len(*signatureKey) == 0 { - log.Panicf("signature public key should be set") - } - signaturePubKey, err := utilities.ParseRsaPublicKey([]byte(*signatureKey)) - if err != nil { - log.Panicf("cannot parse signature key '%s'. %s", *signatureKey, err) - } - - // Instantiate Index and Tools - Index = index.Init(*indexURL, config.GetDataDir()) - Tools = tools.New(config.GetDataDir(), Index, logger, signaturePubKey) - // see if we are supposed to wait 5 seconds if *isLaunchSelf { launchSelfLater() @@ -414,16 +392,14 @@ func loop() { } } - // launch the discoveries for the running system - go serialPorts.Run() // launch the hub routine which is the singleton for the websocket server - go h.run() + go hub.run() // launch our dummy data routine //go d.run() r := gin.New() - socketHandler := wsHandler().ServeHTTP + socketHandler := wsHandler(hub).ServeHTTP extraOrigins := []string{ "https://create.arduino.cc", @@ -462,17 +438,17 @@ func loop() { r.LoadHTMLFiles("templates/nofirefox.html") r.GET("/", homeHandler) - r.POST("/upload", uploadHandler(signaturePubKey)) + r.POST("/upload", uploadHandler(hub, signaturePubKey, tools)) r.GET("/socket.io/", socketHandler) r.POST("/socket.io/", socketHandler) r.Handle("WS", "/socket.io/", socketHandler) r.Handle("WSS", "/socket.io/", socketHandler) r.GET("/info", infoHandler) - r.POST("/pause", pauseHandler) - r.POST("/update", updateHandler) + r.POST("/pause", pauseHandler(hub, stray)) + r.POST("/update", updateHandler(stray)) // Mount goa handlers - goa := v2.Server(config.GetDataDir().String(), Index, signaturePubKey) + goa := v2.Server(config.GetDataDir().String(), index, signaturePubKey) r.Any("/v2/*path", gin.WrapH(goa)) go func() { diff --git a/main_test.go b/main_test.go index 1387fd22..4b0720e0 100644 --- a/main_test.go +++ b/main_test.go @@ -29,9 +29,11 @@ import ( "testing" "github.com/arduino/arduino-create-agent/config" - "github.com/arduino/arduino-create-agent/gen/tools" + genTools "github.com/arduino/arduino-create-agent/gen/tools" "github.com/arduino/arduino-create-agent/globals" "github.com/arduino/arduino-create-agent/index" + "github.com/arduino/arduino-create-agent/systray" + "github.com/arduino/arduino-create-agent/tools" "github.com/arduino/arduino-create-agent/upload" "github.com/arduino/arduino-create-agent/utilities" v2 "github.com/arduino/arduino-create-agent/v2" @@ -56,7 +58,15 @@ func TestValidSignatureKey(t *testing.T) { func TestUploadHandlerAgainstEvilFileNames(t *testing.T) { r := gin.New() - r.POST("/", uploadHandler(utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey)))) + + index := index.Init(*indexURL, config.GetDataDir()) + signaturePubKey, err := utilities.ParseRsaPublicKey([]byte(*signatureKey)) + require.NoError(t, err) + tools := tools.New(config.GetDataDir(), index, signaturePubKey) + hub := newHub(tools, &systray.Systray{}) + pubkey := utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey)) + + r.POST("/", uploadHandler(hub, pubkey, tools)) ts := httptest.NewServer(r) uploadEvilFileName := Upload{ @@ -92,7 +102,15 @@ func TestUploadHandlerAgainstEvilFileNames(t *testing.T) { func TestUploadHandlerAgainstBase64WithoutPaddingMustFail(t *testing.T) { r := gin.New() - r.POST("/", uploadHandler(utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey)))) + + index := index.Init(*indexURL, config.GetDataDir()) + signaturePubKey, err := utilities.ParseRsaPublicKey([]byte(*signatureKey)) + require.NoError(t, err) + tools := tools.New(config.GetDataDir(), index, signaturePubKey) + hub := newHub(tools, &systray.Systray{}) + pubkey := utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey)) + + r.POST("/", uploadHandler(hub, pubkey, tools)) ts := httptest.NewServer(r) defer ts.Close() @@ -126,7 +144,7 @@ func TestInstallToolV2(t *testing.T) { ts := httptest.NewServer(r) type test struct { - request tools.ToolPayload + request genTools.ToolPayload responseCode int responseBody string } @@ -134,7 +152,7 @@ func TestInstallToolV2(t *testing.T) { bossacURL := "http://downloads.arduino.cc/tools/bossac-1.7.0-arduino3-linux64.tar.gz" bossacChecksum := "SHA-256:1ae54999c1f97234a5c603eb99ad39313b11746a4ca517269a9285afa05f9100" bossacSignature := "382898a97b5a86edd74208f10107d2fecbf7059ffe9cc856e045266fb4db4e98802728a0859cfdcda1c0b9075ec01e42dbea1f430b813530d5a6ae1766dfbba64c3e689b59758062dc2ab2e32b2a3491dc2b9a80b9cda4ae514fbe0ec5af210111b6896976053ab76bac55bcecfcececa68adfa3299e3cde6b7f117b3552a7d80ca419374bb497e3c3f12b640cf5b20875416b45e662fc6150b99b178f8e41d6982b4c0a255925ea39773683f9aa9201dc5768b6fc857c87ff602b6a93452a541b8ec10ca07f166e61a9e9d91f0a6090bd2038ed4427af6251039fb9fe8eb62ec30d7b0f3df38bc9de7204dec478fb86f8eb3f71543710790ee169dce039d3e0" - bossacInstallURLOK := tools.ToolPayload{ + bossacInstallURLOK := genTools.ToolPayload{ Name: "bossac", Version: "1.7.0-arduino3", Packager: "arduino", @@ -146,7 +164,7 @@ func TestInstallToolV2(t *testing.T) { esptoolURL := "https://github.com/earlephilhower/esp-quick-toolchain/releases/download/2.5.0-3/x86_64-linux-gnu.esptool-f80ae31.tar.gz" esptoolChecksum := "SHA-256:bded1dca953377838b6086a9bcd40a1dc5286ba5f69f2372c22a1d1819baad24" esptoolSignature := "852b58871419ce5e5633ecfaa72c0f0fa890ceb51164b362b8133bc0e3e003a21cec48935b8cdc078f4031219cbf17fb7edd9d7c9ca8ed85492911c9ca6353c9aa4691eb91fda99563a6bd49aeca0d9981fb05ec76e45c6024f8a6822862ad1e34ddc652fbbf4fa909887a255d4f087398ec386577efcec523c21203be3d10fc9e9b0f990a7536875a77dc2bc5cbffea7734b62238e31719111b718bacccebffc9be689545540e81d23b81caa66214376f58a0d6a45cf7efc5d3af62ab932b371628162fffe403906f41d5534921e5be081c5ac2ecc9db5caec03a105cc44b00ce19a95ad079843501eb8182e0717ce327867380c0e39d2b48698547fc1d0d66" - esptoolInstallURLOK := tools.ToolPayload{ + esptoolInstallURLOK := genTools.ToolPayload{ Name: "esptool", Version: "2.5.0-3-20ed2b9", Packager: "esp8266", @@ -156,7 +174,7 @@ func TestInstallToolV2(t *testing.T) { } wrongSignature := "wr0ngs1gn4tur3" - bossacInstallWrongSig := tools.ToolPayload{ + bossacInstallWrongSig := genTools.ToolPayload{ Name: "bossac", Version: "1.7.0-arduino3", Packager: "arduino", @@ -166,7 +184,7 @@ func TestInstallToolV2(t *testing.T) { } wrongChecksum := "wr0ngch3cksum" - bossacInstallWrongCheck := tools.ToolPayload{ + bossacInstallWrongCheck := genTools.ToolPayload{ Name: "bossac", Version: "1.7.0-arduino3", Packager: "arduino", @@ -175,7 +193,7 @@ func TestInstallToolV2(t *testing.T) { Signature: &bossacSignature, } - bossacInstallNoURL := tools.ToolPayload{ + bossacInstallNoURL := genTools.ToolPayload{ Name: "bossac", Version: "1.7.0-arduino3", Packager: "arduino", diff --git a/serial.go b/serial.go deleted file mode 100755 index 64e5b8f7..00000000 --- a/serial.go +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright 2022 Arduino SA -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published -// by the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -// Supports Windows, Linux, Mac, BeagleBone Black, and Raspberry Pi - -package main - -import ( - "encoding/json" - "slices" - "strconv" - "strings" - "sync" - "time" - - discovery "github.com/arduino/pluggable-discovery-protocol-handler/v2" - "github.com/sirupsen/logrus" -) - -type serialhub struct { - // Opened serial ports. - ports map[*serport]bool - - mu sync.Mutex -} - -// SerialPortList is the serial port list -type SerialPortList struct { - Ports []*SpPortItem - portsLock sync.Mutex -} - -// SpPortItem is the serial port item -type SpPortItem struct { - Name string - SerialNumber string - DeviceClass string - IsOpen bool - IsPrimary bool - Baud int - BufferAlgorithm string - Ver string - VendorID string - ProductID string -} - -// serialPorts contains the ports attached to the machine -var serialPorts SerialPortList - -var sh = serialhub{ - ports: make(map[*serport]bool), -} - -// Register serial ports from the connections. -func (sh *serialhub) Register(port *serport) { - sh.mu.Lock() - //log.Print("Registering a port: ", p.portConf.Name) - h.broadcastSys <- []byte("{\"Cmd\":\"Open\",\"Desc\":\"Got register/open on port.\",\"Port\":\"" + port.portConf.Name + "\",\"Baud\":" + strconv.Itoa(port.portConf.Baud) + ",\"BufferType\":\"" + port.BufferType + "\"}") - sh.ports[port] = true - sh.mu.Unlock() -} - -// Unregister requests from connections. -func (sh *serialhub) Unregister(port *serport) { - sh.mu.Lock() - //log.Print("Unregistering a port: ", p.portConf.Name) - h.broadcastSys <- []byte("{\"Cmd\":\"Close\",\"Desc\":\"Got unregister/close on port.\",\"Port\":\"" + port.portConf.Name + "\",\"Baud\":" + strconv.Itoa(port.portConf.Baud) + "}") - delete(sh.ports, port) - close(port.sendBuffered) - close(port.sendNoBuf) - sh.mu.Unlock() -} - -func (sh *serialhub) FindPortByName(portname string) (*serport, bool) { - sh.mu.Lock() - defer sh.mu.Unlock() - - for port := range sh.ports { - if strings.EqualFold(port.portConf.Name, portname) { - // we found our port - //spHandlerClose(port) - return port, true - } - } - return nil, false -} - -// List broadcasts a Json representation of the ports found -func (sp *SerialPortList) List() { - sp.portsLock.Lock() - ls, err := json.MarshalIndent(sp, "", "\t") - sp.portsLock.Unlock() - - if err != nil { - //log.Println(err) - h.broadcastSys <- []byte("Error creating json on port list " + - err.Error()) - } else { - h.broadcastSys <- ls - } -} - -// Run is the main loop for port discovery and management -func (sp *SerialPortList) Run() { - for retries := 0; retries < 10; retries++ { - sp.runSerialDiscovery() - - logrus.Errorf("Serial discovery stopped working, restarting it in 10 seconds...") - time.Sleep(10 * time.Second) - } - logrus.Errorf("Failed restarting serial discovery. Giving up...") -} - -func (sp *SerialPortList) runSerialDiscovery() { - // First ensure that all the discoveries are available - if err := Tools.Download("builtin", "serial-discovery", "latest", "keep"); err != nil { - logrus.Errorf("Error downloading serial-discovery: %s", err) - panic(err) - } - sd, err := Tools.GetLocation("serial-discovery") - if err != nil { - logrus.Errorf("Error downloading serial-discovery: %s", err) - panic(err) - } - d := discovery.NewClient("serial", sd+"/serial-discovery") - dLogger := logrus.WithField("discovery", "serial") - if *verbose { - d.SetLogger(dLogger) - } - d.SetUserAgent("arduino-create-agent/" + version) - if err := d.Run(); err != nil { - logrus.Errorf("Error running serial-discovery: %s", err) - panic(err) - } - defer d.Quit() - - events, err := d.StartSync(10) - if err != nil { - logrus.Errorf("Error starting event watcher on serial-discovery: %s", err) - panic(err) - } - - logrus.Infof("Serial discovery started, watching for events") - for ev := range events { - logrus.WithField("event", ev).Debugf("Serial discovery event") - switch ev.Type { - case "add": - sp.add(ev.Port) - case "remove": - sp.remove(ev.Port) - } - } - - sp.reset() - logrus.Errorf("Serial discovery stopped.") -} - -func (sp *SerialPortList) reset() { - sp.portsLock.Lock() - defer sp.portsLock.Unlock() - sp.Ports = []*SpPortItem{} -} - -func (sp *SerialPortList) add(addedPort *discovery.Port) { - if addedPort.Protocol != "serial" { - return - } - props := addedPort.Properties - if !props.ContainsKey("vid") { - return - } - vid, pid := props.Get("vid"), props.Get("pid") - if vid == "0x0000" || pid == "0x0000" { - return - } - if portsFilter != nil && !portsFilter.MatchString(addedPort.Address) { - logrus.Debugf("ignoring port not matching filter. port: %v\n", addedPort.Address) - return - } - - sp.portsLock.Lock() - defer sp.portsLock.Unlock() - - // If the port is already in the list, just update the metadata... - for _, oldPort := range sp.Ports { - if oldPort.Name == addedPort.Address { - oldPort.SerialNumber = props.Get("serialNumber") - oldPort.VendorID = vid - oldPort.ProductID = pid - return - } - } - // ...otherwise, add it to the list - sp.Ports = append(sp.Ports, &SpPortItem{ - Name: addedPort.Address, - SerialNumber: props.Get("serialNumber"), - VendorID: vid, - ProductID: pid, - Ver: version, - IsOpen: false, - IsPrimary: false, - Baud: 0, - BufferAlgorithm: "", - }) -} - -func (sp *SerialPortList) remove(removedPort *discovery.Port) { - sp.portsLock.Lock() - defer sp.portsLock.Unlock() - - // Remove the port from the list - sp.Ports = slices.DeleteFunc(sp.Ports, func(oldPort *SpPortItem) bool { - return oldPort.Name == removedPort.Address - }) -} - -// MarkPortAsOpened marks a port as opened by the user -func (sp *SerialPortList) MarkPortAsOpened(portname string) { - sp.portsLock.Lock() - defer sp.portsLock.Unlock() - port := sp.getPortByName(portname) - if port != nil { - port.IsOpen = true - } -} - -// MarkPortAsClosed marks a port as no more opened by the user -func (sp *SerialPortList) MarkPortAsClosed(portname string) { - sp.portsLock.Lock() - defer sp.portsLock.Unlock() - port := sp.getPortByName(portname) - if port != nil { - port.IsOpen = false - } -} - -func (sp *SerialPortList) getPortByName(portname string) *SpPortItem { - for _, port := range sp.Ports { - if port.Name == portname { - return port - } - } - return nil -} - -func spErr(err string) { - //log.Println("Sending err back: ", err) - //h.broadcastSys <- []byte(err) - h.broadcastSys <- []byte("{\"Error\" : \"" + err + "\"}") -} - -func spClose(portname string) { - if myport, ok := sh.FindPortByName(portname); ok { - h.broadcastSys <- []byte("Closing serial port " + portname) - myport.Close() - } else { - spErr("We could not find the serial port " + portname + " that you were trying to close.") - } -} - -func spWrite(arg string) { - // we will get a string of comXX asdf asdf asdf - //log.Println("Inside spWrite arg: " + arg) - arg = strings.TrimPrefix(arg, " ") - //log.Println("arg after trim: " + arg) - args := strings.SplitN(arg, " ", 3) - if len(args) != 3 { - errstr := "Could not parse send command: " + arg - //log.Println(errstr) - spErr(errstr) - return - } - bufferingMode := args[0] - portname := strings.Trim(args[1], " ") - data := args[2] - - //log.Println("The port to write to is:" + portname + "---") - //log.Println("The data is:" + data + "---") - - // See if we have this port open - port, ok := sh.FindPortByName(portname) - if !ok { - // we couldn't find the port, so send err - spErr("We could not find the serial port " + portname + " that you were trying to write to.") - return - } - - // see if bufferingMode is valid - switch bufferingMode { - case "send", "sendnobuf", "sendraw": - // valid buffering mode, go ahead - default: - spErr("Unsupported send command:" + args[0] + ". Please specify a valid one") - return - } - - // send it to the write channel - port.Write(data, bufferingMode) -} diff --git a/serialhub.go b/serialhub.go new file mode 100755 index 00000000..06a29003 --- /dev/null +++ b/serialhub.go @@ -0,0 +1,74 @@ +// Copyright 2022 Arduino SA +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Supports Windows, Linux, Mac, BeagleBone Black, and Raspberry Pi + +package main + +import ( + "strings" + "sync" +) + +type serialhub struct { + // Opened serial ports. + ports map[*serport]bool + mu sync.Mutex + + onRegister func(port *serport) + onUnregister func(port *serport) +} + +func newSerialHub(onRegister func(port *serport), onUnregister func(port *serport)) *serialhub { + return &serialhub{ + ports: make(map[*serport]bool), + onRegister: onRegister, + onUnregister: onUnregister, + } +} + +// Register serial ports from the connections. +func (sh *serialhub) Register(port *serport) { + sh.mu.Lock() + //log.Print("Registering a port: ", p.portConf.Name) + sh.onRegister(port) + sh.ports[port] = true + sh.mu.Unlock() +} + +// Unregister requests from connections. +func (sh *serialhub) Unregister(port *serport) { + sh.mu.Lock() + //log.Print("Unregistering a port: ", p.portConf.Name) + sh.onUnregister(port) + delete(sh.ports, port) + close(port.sendBuffered) + close(port.sendNoBuf) + sh.mu.Unlock() +} + +func (sh *serialhub) FindPortByName(portname string) (*serport, bool) { + sh.mu.Lock() + defer sh.mu.Unlock() + + for port := range sh.ports { + if strings.EqualFold(port.portConf.Name, portname) { + // we found our port + //spHandlerClose(port) + return port, true + } + } + return nil, false +} diff --git a/serialport.go b/serialport.go index 0d386bbf..4385a3e8 100755 --- a/serialport.go +++ b/serialport.go @@ -25,7 +25,6 @@ import ( "unicode/utf8" log "github.com/sirupsen/logrus" - serial "go.bug.st/serial" ) // SerialConfig is the serial port configuration @@ -61,6 +60,9 @@ type serport struct { BufferType string //bufferwatcher *BufferflowDummypause bufferwatcher Bufferflow + + OnMessage func([]byte) + OnClose func(*serport) } // SpPortMessage is the serial port message @@ -89,7 +91,7 @@ func (p *serport) reader(buftype string) { if p.isClosing.Load() { strmsg := "Shutting down reader on " + p.portConf.Name log.Println(strmsg) - h.broadcastSys <- []byte(strmsg) + p.OnMessage([]byte(strmsg)) break } @@ -143,15 +145,14 @@ func (p *serport) reader(buftype string) { if err == io.EOF || err == io.ErrUnexpectedEOF { // hit end of file log.Println("Hit end of file on serial port") - h.broadcastSys <- []byte("{\"Cmd\":\"OpenFail\",\"Desc\":\"Got EOF (End of File) on port which usually means another app other than Serial Port JSON Server is locking your port. " + err.Error() + "\",\"Port\":\"" + p.portConf.Name + "\",\"Baud\":" + strconv.Itoa(p.portConf.Baud) + "}") + p.OnMessage([]byte("{\"Cmd\":\"OpenFail\",\"Desc\":\"Got EOF (End of File) on port which usually means another app other than Serial Port JSON Server is locking your port. " + err.Error() + "\",\"Port\":\"" + p.portConf.Name + "\",\"Baud\":" + strconv.Itoa(p.portConf.Baud) + "}")) } if err != nil { log.Println(err) - h.broadcastSys <- []byte("Error reading on " + p.portConf.Name + " " + - err.Error() + " Closing port.") - h.broadcastSys <- []byte("{\"Cmd\":\"OpenFail\",\"Desc\":\"Got error reading on port. " + err.Error() + "\",\"Port\":\"" + p.portConf.Name + "\",\"Baud\":" + strconv.Itoa(p.portConf.Baud) + "}") + p.OnMessage([]byte("Error reading on " + p.portConf.Name + " " + err.Error() + " Closing port.")) + p.OnMessage([]byte("{\"Cmd\":\"OpenFail\",\"Desc\":\"Got error reading on port. " + err.Error() + "\",\"Port\":\"" + p.portConf.Name + "\",\"Baud\":" + strconv.Itoa(p.portConf.Baud) + "}")) p.isClosingDueToError = true break } @@ -209,7 +210,7 @@ func (p *serport) writerBuffered() { } msgstr := "writerBuffered just got closed. make sure you make a new one. port:" + p.portConf.Name log.Println(msgstr) - h.broadcastSys <- []byte(msgstr) + p.OnMessage([]byte(msgstr)) } // this method runs as its own thread because it's instantiated @@ -230,15 +231,22 @@ func (p *serport) writerNoBuf() { if err != nil { errstr := "Error writing to " + p.portConf.Name + " " + err.Error() + " Closing port." log.Print(errstr) - h.broadcastSys <- []byte(errstr) + p.OnMessage([]byte(errstr)) break } } msgstr := "Shutting down writer on " + p.portConf.Name log.Println(msgstr) - h.broadcastSys <- []byte(msgstr) + p.OnMessage([]byte(msgstr)) + p.portIo.Close() - serialPorts.List() + + // NOTE: by removing the 'serialPorts.List()' line, + // the list of serial ports are NOT sent to the websocket clients after a write is completed. + // This should not be an issue since the list are periodically called. + // Note also that the 'writerBuffered' and 'writerRaw' methods do not call it. + // serialPorts.List() + } // this method runs as its own thread because it's instantiated @@ -270,82 +278,7 @@ func (p *serport) writerRaw() { } msgstr := "writerRaw just got closed. make sure you make a new one. port:" + p.portConf.Name log.Println(msgstr) - h.broadcastSys <- []byte(msgstr) -} - -func spHandlerOpen(portname string, baud int, buftype string) { - - log.Print("Inside spHandler") - - var out bytes.Buffer - - out.WriteString("Opening serial port ") - out.WriteString(portname) - out.WriteString(" at ") - out.WriteString(strconv.Itoa(baud)) - out.WriteString(" baud") - log.Print(out.String()) - - conf := &SerialConfig{Name: portname, Baud: baud, RtsOn: true} - - mode := &serial.Mode{ - BaudRate: baud, - } - - sp, err := serial.Open(portname, mode) - log.Print("Just tried to open port") - if err != nil { - //log.Fatal(err) - log.Print("Error opening port " + err.Error()) - //h.broadcastSys <- []byte("Error opening port. " + err.Error()) - h.broadcastSys <- []byte("{\"Cmd\":\"OpenFail\",\"Desc\":\"Error opening port. " + err.Error() + "\",\"Port\":\"" + conf.Name + "\",\"Baud\":" + strconv.Itoa(conf.Baud) + "}") - - return - } - log.Print("Opened port successfully") - //p := &serport{send: make(chan []byte, 256), portConf: conf, portIo: sp} - // we can go up to 256,000 lines of gcode in the buffer - p := &serport{ - sendBuffered: make(chan string, 256000), - sendNoBuf: make(chan []byte), - sendRaw: make(chan string), - portConf: conf, - portIo: sp, - portName: portname, - BufferType: buftype} - - var bw Bufferflow - - switch buftype { - case "timed": - bw = NewBufferflowTimed(portname, h.broadcastSys) - case "timedraw": - bw = NewBufferflowTimedRaw(portname, h.broadcastSys) - case "default": - bw = NewBufferflowDefault(portname, h.broadcastSys) - default: - log.Panicf("unknown buffer type: %s", buftype) - } - - bw.Init() - p.bufferwatcher = bw - - sh.Register(p) - defer sh.Unregister(p) - - serialPorts.MarkPortAsOpened(portname) - serialPorts.List() - - // this is internally buffered thread to not send to serial port if blocked - go p.writerBuffered() - // this is thread to send to serial port regardless of block - go p.writerNoBuf() - // this is thread to send to serial port but with base64 decoding - go p.writerRaw() - - p.reader(buftype) - - serialPorts.List() + p.OnMessage([]byte(msgstr)) } func (p *serport) Close() { @@ -353,6 +286,5 @@ func (p *serport) Close() { p.bufferwatcher.Close() p.portIo.Close() - serialPorts.MarkPortAsClosed(p.portName) - serialPorts.List() + p.OnClose(p) } diff --git a/serialportlist.go b/serialportlist.go new file mode 100644 index 00000000..9b1fdfac --- /dev/null +++ b/serialportlist.go @@ -0,0 +1,201 @@ +package main + +import ( + "encoding/json" + "slices" + "sync" + "time" + + "github.com/arduino/arduino-create-agent/tools" + discovery "github.com/arduino/pluggable-discovery-protocol-handler/v2" + "github.com/sirupsen/logrus" +) + +type serialPortList struct { + Ports []*SpPortItem + portsLock sync.Mutex + + tools *tools.Tools `json:"-"` + OnList func([]byte) `json:"-"` + OnErr func(string) `json:"-"` +} + +// SpPortItem is the serial port item +type SpPortItem struct { + Name string + SerialNumber string + DeviceClass string + IsOpen bool + IsPrimary bool + Baud int + BufferAlgorithm string + Ver string + VendorID string + ProductID string +} + +func newSerialPortList(tools *tools.Tools, onList func(data []byte), onErr func(err string)) *serialPortList { + return &serialPortList{ + tools: tools, + OnList: onList, + OnErr: onErr, + } +} + +// List broadcasts a Json representation of the ports found +func (sp *serialPortList) List() { + sp.portsLock.Lock() + ls, err := json.MarshalIndent(sp, "", "\t") + sp.portsLock.Unlock() + + if err != nil { + sp.OnErr("Error creating json on port list " + err.Error()) + } else { + sp.OnList(ls) + } +} + +// MarkPortAsOpened marks a port as opened by the user +func (sp *serialPortList) MarkPortAsOpened(portname string) { + sp.portsLock.Lock() + defer sp.portsLock.Unlock() + port := sp.getPortByName(portname) + if port != nil { + port.IsOpen = true + } +} + +// MarkPortAsClosed marks a port as no more opened by the user +func (sp *serialPortList) MarkPortAsClosed(portname string) { + sp.portsLock.Lock() + defer sp.portsLock.Unlock() + port := sp.getPortByName(portname) + if port != nil { + port.IsOpen = false + } +} + +// Run is the main loop for port discovery and management +func (sp *serialPortList) Run() { + for retries := 0; retries < 10; retries++ { + sp.runSerialDiscovery() + + logrus.Errorf("Serial discovery stopped working, restarting it in 10 seconds...") + time.Sleep(10 * time.Second) + } + logrus.Errorf("Failed restarting serial discovery. Giving up...") +} + +func (sp *serialPortList) runSerialDiscovery() { + // First ensure that all the discoveries are available + noOpProgress := func(msg string) {} + if err := sp.tools.Download("builtin", "serial-discovery", "latest", "keep", noOpProgress); err != nil { + logrus.Errorf("Error downloading serial-discovery: %s", err) + panic(err) + } + sd, err := sp.tools.GetLocation("serial-discovery") + if err != nil { + logrus.Errorf("Error downloading serial-discovery: %s", err) + panic(err) + } + d := discovery.NewClient("serial", sd+"/serial-discovery") + dLogger := logrus.WithField("discovery", "serial") + if *verbose { + d.SetLogger(dLogger) + } + d.SetUserAgent("arduino-create-agent/" + version) + if err := d.Run(); err != nil { + logrus.Errorf("Error running serial-discovery: %s", err) + panic(err) + } + defer d.Quit() + + events, err := d.StartSync(10) + if err != nil { + logrus.Errorf("Error starting event watcher on serial-discovery: %s", err) + panic(err) + } + d.List() + + logrus.Infof("Serial discovery started, watching for events") + for ev := range events { + logrus.WithField("event", ev).Debugf("Serial discovery event") + switch ev.Type { + case "add": + sp.add(ev.Port) + case "remove": + sp.remove(ev.Port) + } + } + + sp.reset() + logrus.Errorf("Serial discovery stopped.") +} + +func (sp *serialPortList) reset() { + sp.portsLock.Lock() + defer sp.portsLock.Unlock() + sp.Ports = []*SpPortItem{} +} + +func (sp *serialPortList) add(addedPort *discovery.Port) { + if addedPort.Protocol != "serial" { + return + } + props := addedPort.Properties + if !props.ContainsKey("vid") { + return + } + vid, pid := props.Get("vid"), props.Get("pid") + if vid == "0x0000" || pid == "0x0000" { + return + } + if portsFilter != nil && !portsFilter.MatchString(addedPort.Address) { + logrus.Debugf("ignoring port not matching filter. port: %v\n", addedPort.Address) + return + } + + sp.portsLock.Lock() + defer sp.portsLock.Unlock() + + // If the port is already in the list, just update the metadata... + for _, oldPort := range sp.Ports { + if oldPort.Name == addedPort.Address { + oldPort.SerialNumber = props.Get("serialNumber") + oldPort.VendorID = vid + oldPort.ProductID = pid + return + } + } + // ...otherwise, add it to the list + sp.Ports = append(sp.Ports, &SpPortItem{ + Name: addedPort.Address, + SerialNumber: props.Get("serialNumber"), + VendorID: vid, + ProductID: pid, + Ver: version, + IsOpen: false, + IsPrimary: false, + Baud: 0, + BufferAlgorithm: "", + }) +} + +func (sp *serialPortList) remove(removedPort *discovery.Port) { + sp.portsLock.Lock() + defer sp.portsLock.Unlock() + + // Remove the port from the list + sp.Ports = slices.DeleteFunc(sp.Ports, func(oldPort *SpPortItem) bool { + return oldPort.Name == removedPort.Address + }) +} + +func (sp *serialPortList) getPortByName(portname string) *SpPortItem { + for _, port := range sp.Ports { + if port.Name == portname { + return port + } + } + return nil +} diff --git a/tools/download.go b/tools/download.go index 8c4a37a6..da7df680 100644 --- a/tools/download.go +++ b/tools/download.go @@ -42,7 +42,7 @@ import ( // If version is not "latest" and behaviour is "replace", it will download the // version again. If instead behaviour is "keep" it will not download the version // if it already exists. -func (t *Tools) Download(pack, name, version, behaviour string) error { +func (t *Tools) Download(pack, name, version, behaviour string, report func(msg string)) error { t.tools.SetBehaviour(behaviour) _, err := t.tools.Install(context.Background(), &tools.ToolPayload{Name: name, Version: version, Packager: pack}) @@ -58,16 +58,16 @@ func (t *Tools) Download(pack, name, version, behaviour string) error { // if the tool contains a post_install script, run it: it means it is a tool that needs to install drivers // AFAIK this is only the case for the windows-driver tool - err = t.installDrivers(safePath) + err = t.installDrivers(safePath, report) if err != nil { return err } // Ensure that the files are executable - t.logger("Ensure that the files are executable") + report("Ensure that the files are executable") // Update the tool map - t.logger("Updating map with location " + safePath) + report("Updating map with location " + safePath) t.setMapValue(name, safePath) t.setMapValue(name+"-"+version, safePath) @@ -75,7 +75,7 @@ func (t *Tools) Download(pack, name, version, behaviour string) error { return nil } -func (t *Tools) installDrivers(location string) error { +func (t *Tools) installDrivers(location string, report func(msg string)) error { OkPressed := 6 extension := ".bat" // add .\ to force locality @@ -86,11 +86,11 @@ func (t *Tools) installDrivers(location string) error { preamble = "./" } if _, err := os.Stat(filepath.Join(location, "post_install"+extension)); err == nil { - t.logger("Installing drivers") + report("Installing drivers") ok := MessageBox("Installing drivers", "We are about to install some drivers needed to use Arduino/Genuino boards\nDo you want to continue?") if ok == OkPressed { os.Chdir(location) - t.logger(preamble + "post_install" + extension) + report(preamble + "post_install" + extension) oscmd := exec.Command(preamble + "post_install" + extension) if runtime.GOOS != "linux" { // spawning a shell could be the only way to let the user type his password diff --git a/tools/download_test.go b/tools/download_test.go index 96a105fd..b99f1077 100644 --- a/tools/download_test.go +++ b/tools/download_test.go @@ -130,12 +130,12 @@ func TestDownload(t *testing.T) { IndexFile: *paths.New("testdata", "test_tool_index.json"), LastRefresh: time.Now(), } - testTools := New(tempDirPath, &testIndex, func(msg string) { t.Log(msg) }, utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey))) + testTools := New(tempDirPath, &testIndex, utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey))) for _, tc := range testCases { t.Run(tc.name+"-"+tc.version, func(t *testing.T) { // Download the tool - err := testTools.Download("arduino-test", tc.name, tc.version, "replace") + err := testTools.Download("arduino-test", tc.name, tc.version, "replace", func(msg string) { t.Log(msg) }) require.NoError(t, err) // Check that the tool has been downloaded @@ -177,8 +177,8 @@ func TestCorruptedInstalled(t *testing.T) { defer fileJSON.Close() _, err = fileJSON.Write([]byte("Hello")) require.NoError(t, err) - testTools := New(tempDirPath, &testIndex, func(msg string) { t.Log(msg) }, utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey))) + testTools := New(tempDirPath, &testIndex, utilities.MustParseRsaPublicKey([]byte(globals.ArduinoSignaturePubKey))) // Download the tool - err = testTools.Download("arduino-test", "avrdude", "6.3.0-arduino17", "keep") + err = testTools.Download("arduino-test", "avrdude", "6.3.0-arduino17", "keep", func(msg string) { t.Log(msg) }) require.NoError(t, err) } diff --git a/tools/tools.go b/tools/tools.go index f371126b..0ad95763 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -46,7 +46,6 @@ import ( type Tools struct { directory *paths.Path index *index.Resource - logger func(msg string) installed map[string]string mutex sync.RWMutex tools *pkgs.Tools @@ -56,11 +55,10 @@ type Tools struct { // The New functions accept the directory to use to host the tools, // an index (used to download the tools), // and a logger to log the operations -func New(directory *paths.Path, index *index.Resource, logger func(msg string), signPubKey *rsa.PublicKey) *Tools { +func New(directory *paths.Path, index *index.Resource, signPubKey *rsa.PublicKey) *Tools { t := &Tools{ directory: directory, index: index, - logger: logger, installed: map[string]string{}, mutex: sync.RWMutex{}, tools: pkgs.New(index, directory.String(), "replace", signPubKey), diff --git a/update.go b/update.go index 33c028bc..a12f1feb 100644 --- a/update.go +++ b/update.go @@ -30,20 +30,23 @@ package main import ( + "github.com/arduino/arduino-create-agent/systray" "github.com/arduino/arduino-create-agent/updater" "github.com/gin-gonic/gin" ) -func updateHandler(c *gin.Context) { - restartPath, err := updater.CheckForUpdates(version, *updateURL, *appName) - if err != nil { - c.JSON(500, gin.H{"error": err.Error()}) - return - } - c.JSON(200, gin.H{"success": "Please wait a moment while the agent reboots itself"}) - if restartPath == "quit" { - Systray.Quit() - } else { - Systray.RestartWith(restartPath) +func updateHandler(s *systray.Systray) func(c *gin.Context) { + return func(c *gin.Context) { + restartPath, err := updater.CheckForUpdates(version, *updateURL, *appName) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + c.JSON(200, gin.H{"success": "Please wait a moment while the agent reboots itself"}) + if restartPath == "quit" { + s.Quit() + } else { + s.RestartWith(restartPath) + } } }