Skip to content

Commit

Permalink
feat(vmtranslator): add stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
qazwsxedckll committed Feb 5, 2024
1 parent c705782 commit b44b7d0
Show file tree
Hide file tree
Showing 18 changed files with 2,066 additions and 37 deletions.
209 changes: 193 additions & 16 deletions vmtranslator/code_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@ package main

import (
"bytes"
"errors"
"fmt"
"io"
"io/fs"
"regexp"
"strconv"
)

var validLabel = `^[A-Za-z_:.][\w_:.]*$`

type CodeWriter struct {
fileName string
w io.WriteCloser
jump int
retIndex int
function string
}

func NewCodeWriter(w io.WriteCloser) *CodeWriter {
Expand All @@ -19,7 +26,7 @@ func NewCodeWriter(w io.WriteCloser) *CodeWriter {
}
}

func (c *CodeWriter) SetFileNmae(fileName string) {
func (c *CodeWriter) SetFileName(fileName string) {
c.fileName = fileName
}

Expand All @@ -28,13 +35,14 @@ func (c *CodeWriter) WriteArithmetic(command string) {

switch command {
case "add":
popBinary(&buffer)
popUnaryAndGetTop(&buffer)
buffer.WriteString("M=M+D\n")
case "sub":
popBinary(&buffer)
popUnaryAndGetTop(&buffer)
buffer.WriteString("M=M-D\n")
case "neg":
popUnary(&buffer)
buffer.WriteString("@SP\n")
buffer.WriteString("A=M-1\n")
buffer.WriteString("M=-M\n")
case "eq":
c.jump = writeCompare(&buffer, c.jump, "JEQ")
Expand All @@ -43,13 +51,14 @@ func (c *CodeWriter) WriteArithmetic(command string) {
case "lt":
c.jump = writeCompare(&buffer, c.jump, "JLT")
case "and":
popBinary(&buffer)
popUnaryAndGetTop(&buffer)
buffer.WriteString("M=M&D\n")
case "or":
popBinary(&buffer)
popUnaryAndGetTop(&buffer)
buffer.WriteString("M=M|D\n")
case "not":
popUnary(&buffer)
buffer.WriteString("@SP\n")
buffer.WriteString("A=M-1\n")
buffer.WriteString("M=!M\n")
}

Expand Down Expand Up @@ -137,7 +146,6 @@ func (c *CodeWriter) WritePushPop(command CommandType, segment string, index int

popUnary(&buffer)

buffer.WriteString("D=M\n")
buffer.WriteString("@R13\n")
buffer.WriteString("A=M\n")
buffer.WriteString("M=D\n")
Expand All @@ -151,10 +159,180 @@ func (c *CodeWriter) WritePushPop(command CommandType, segment string, index int
}
}

func (c *CodeWriter) WriteInit() {
buffer := bytes.Buffer{}
buffer.WriteString("@256\n")
buffer.WriteString("D=A\n")
buffer.WriteString("@SP\n")
buffer.WriteString("M=D\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}

c.WriteCall("Sys.init", 0)
}

func (c *CodeWriter) WriteLabel(label string) {
reg := regexp.MustCompile(validLabel)
if !reg.MatchString(label) {
panic("invalid label: " + label)
}

buffer := bytes.Buffer{}
buffer.WriteString("(" + c.function + "$" + label + ")\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}
}

func (c *CodeWriter) WriteGoto(label string) {
buffer := bytes.Buffer{}
buffer.WriteString("@" + c.function + "$" + label + "\n")
buffer.WriteString("0;JMP\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}
}

func (c *CodeWriter) WriteIf(label string) {
buffer := bytes.Buffer{}
popUnary(&buffer)
buffer.WriteString("@" + c.function + "$" + label + "\n")
buffer.WriteString("D;JNE\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}
}

func (c *CodeWriter) WriteCall(functionName string, numArgs int) {
buffer := bytes.Buffer{}
returnAddress := "_" + functionName + "_RETURN_" + strconv.Itoa(c.retIndex)
c.retIndex++
// push return address
buffer.WriteString("@" + returnAddress + "\n")
buffer.WriteString("D=A\n")
push(&buffer)
// push LCL, ARG, THIS, THAT
buffer.WriteString("@LCL\n")
buffer.WriteString("D=M\n")
push(&buffer)
buffer.WriteString("@ARG\n")
buffer.WriteString("D=M\n")
push(&buffer)
buffer.WriteString("@THIS\n")
buffer.WriteString("D=M\n")
push(&buffer)
buffer.WriteString("@THAT\n")
buffer.WriteString("D=M\n")
push(&buffer)
// ARG = SP - (n + 5)
buffer.WriteString("@" + strconv.Itoa(numArgs+5) + "\n")
buffer.WriteString("D=D-A\n")
buffer.WriteString("@ARG\n")
buffer.WriteString("M=D\n")
// LCL = SP
buffer.WriteString("@SP\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@LCL\n")
buffer.WriteString("M=D\n")
// goto function
buffer.WriteString("@" + functionName + "\n")
buffer.WriteString("0;JMP\n")
// return address
buffer.WriteString("(" + returnAddress + ")\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}
}

func (c *CodeWriter) WriteReturn() {
buffer := bytes.Buffer{}
// FRAME = LCL
buffer.WriteString("@LCL\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@R13\n")
buffer.WriteString("M=D\n")
// RET = *(FRAME - 5)
// if arg num is 0, pop() will overwrite RET, so we need to save it
buffer.WriteString("@5\n")
buffer.WriteString("A=D-A\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@R14\n")
buffer.WriteString("M=D\n")
// *ARG = pop()
popUnary(&buffer)
buffer.WriteString("@ARG\n")
buffer.WriteString("A=M\n")
buffer.WriteString("M=D\n")
// SP = ARG + 1
buffer.WriteString("@ARG\n")
buffer.WriteString("D=M+1\n")
buffer.WriteString("@SP\n")
buffer.WriteString("M=D\n")
// THAT = *(FRAME - 1)
buffer.WriteString("@R13\n")
buffer.WriteString("AM=M-1\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@THAT\n")
buffer.WriteString("M=D\n")
// THIS = *(FRAME - 2)
buffer.WriteString("@R13\n")
buffer.WriteString("AM=M-1\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@THIS\n")
buffer.WriteString("M=D\n")
// ARG = *(FRAME - 3)
buffer.WriteString("@R13\n")
buffer.WriteString("AM=M-1\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@ARG\n")
buffer.WriteString("M=D\n")
// LCL = *(FRAME - 4)
buffer.WriteString("@R13\n")
buffer.WriteString("AM=M-1\n")
buffer.WriteString("D=M\n")
buffer.WriteString("@LCL\n")
buffer.WriteString("M=D\n")
// goto RET
buffer.WriteString("@R14\n")
buffer.WriteString("A=M\n")
buffer.WriteString("0;JMP\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}
}

func (c *CodeWriter) WriteFunction(functionName string, numLocals int) {
reg := regexp.MustCompile(validLabel)
if !reg.MatchString(functionName) {
panic("invalid function name: " + functionName)
}

buffer := bytes.Buffer{}
c.function = functionName
buffer.WriteString("(" + functionName + ")\n")
_, err := c.w.Write(buffer.Bytes())
if err != nil {
panic(err)
}

for i := 0; i < numLocals; i++ {
c.WritePushPop(C_PUSH, "constant", 0)
}
}

func (c *CodeWriter) Close() {
err := c.w.Close()
if err != nil {
panic(err)
if !errors.Is(err, fs.ErrClosed) {
panic(err)
}
}
}

Expand All @@ -171,7 +349,7 @@ func writeTrue(buf *bytes.Buffer) {
}

func writeCompare(buffer *bytes.Buffer, jump int, compare string) int {
popBinary(buffer)
popUnaryAndGetTop(buffer)
buffer.WriteString("D=M-D\n")
buffer.WriteString("@JUMP_" + strconv.Itoa(jump) + "\n")
buffer.WriteString("D;" + compare + "\n")
Expand All @@ -192,18 +370,17 @@ func writeOffset(buf *bytes.Buffer, index int) {
buf.WriteString("D=M\n")
}

// one value in D, one value in M
func popBinary(buf *bytes.Buffer) {
buf.WriteString("@SP\n")
buf.WriteString("AM=M-1\n")
buf.WriteString("D=M\n")
// pop value and put it in D, then point to the last value
func popUnaryAndGetTop(buf *bytes.Buffer) {
popUnary(buf)
buf.WriteString("A=A-1\n")
}

// one value in M
// pop value and put it in D
func popUnary(buf *bytes.Buffer) {
buf.WriteString("@SP\n")
buf.WriteString("AM=M-1\n")
buf.WriteString("D=M\n")
}

func push(buf *bytes.Buffer) {
Expand Down
81 changes: 62 additions & 19 deletions vmtranslator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,66 @@ func main() {
}

arg := os.Args[1]
err := filepath.WalkDir(arg, func(path string, d os.DirEntry, err error) error {

if strings.HasSuffix(arg, ".vm") {
out, err := os.Create(strings.TrimSuffix(filepath.Base(arg), ".vm") + ".asm")
if err != nil {
panic(err)
}

codeWriter := NewCodeWriter(out)
defer codeWriter.Close()
codeWriter.WriteInit()

in, err := os.Open(arg)
if err != nil {
return err
panic(err)
}

if strings.HasSuffix(path, ".vm") {
in, err := os.Open(path)
parse(in, codeWriter, arg)
} else {
vmFound := false
var codeWriter *CodeWriter

err := filepath.WalkDir(arg, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}
parse(in, path)
}

return nil
})
if err != nil {
panic(err)
}
}
if strings.HasSuffix(path, ".vm") {
if !vmFound {
dir := filepath.Dir(path)
out, err := os.Create(dir + string(filepath.Separator) + filepath.Base(dir) + ".asm")
if err != nil {
panic(err)
}
codeWriter = NewCodeWriter(out)
codeWriter.WriteInit()
vmFound = true
}

func parse(in io.Reader, path string) {
out, err := os.Create(path[:len(path)-3] + ".asm")
if err != nil {
panic(err)
in, err := os.Open(path)
if err != nil {
return err
}
parse(in, codeWriter, path)
} else {
if codeWriter != nil {
codeWriter.Close()
}
vmFound = false
}

return nil
})
if err != nil {
panic(err)
}
}
}

codeWriter := NewCodeWriter(out)
codeWriter.SetFileNmae(strings.TrimSuffix(filepath.Base(path), ".vm"))
defer codeWriter.w.Close()
func parse(in io.Reader, codeWriter *CodeWriter, path string) {
codeWriter.SetFileName(strings.TrimSuffix(filepath.Base(path), ".vm"))

parser, err := NewParser(in)
if err != nil {
Expand All @@ -59,6 +90,18 @@ func parse(in io.Reader, path string) {
codeWriter.WriteArithmetic(parser.Arg1())
case C_PUSH, C_POP:
codeWriter.WritePushPop(cmdType, parser.Arg1(), parser.Arg2())
case C_LABEL:
codeWriter.WriteLabel(parser.Arg1())
case C_GOTO:
codeWriter.WriteGoto(parser.Arg1())
case C_IF:
codeWriter.WriteIf(parser.Arg1())
case C_FUNCTION:
codeWriter.WriteFunction(parser.Arg1(), parser.Arg2())
case C_RETURN:
codeWriter.WriteReturn()
case C_CALL:
codeWriter.WriteCall(parser.Arg1(), parser.Arg2())
default:
panic(fmt.Sprintf("unknown command type: %v", cmdType))
}
Expand Down
Loading

0 comments on commit b44b7d0

Please sign in to comment.