Skip to content

Commit

Permalink
Merge pull request #7 from streamer45/model-v5
Browse files Browse the repository at this point in the history
Model v5
  • Loading branch information
streamer45 authored Jul 8, 2024
2 parents 4567352 + f66dddc commit dbd4af5
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 131 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
on: [push]
name: CI
jobs:
test:
env:
GOPATH: ${{ github.workspace }}

defaults:
run:
working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}

strategy:
matrix:
go-version: [1.21.x]

runs-on: ubuntu-latest

steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Checkout Code
uses: actions/checkout@v2
with:
path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}
- name: Install ONNX
run: |
wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.2/onnxruntime-linux-x64-1.16.2.tgz
tar xf onnxruntime-linux-x64-1.16.2.tgz
sudo cp ./onnxruntime-linux-x64-1.16.2/include/* /usr/local/include
sudo cp ./onnxruntime-linux-x64-1.16.2/lib/* /usr/local/lib
sudo ldconfig
- name: Execute Tests
run: |
go mod download
go mod verify
make test
env:
CI: true
29 changes: 29 additions & 0 deletions .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: golangci-lint
on: [push]
permissions:
contents: read
jobs:
golangci:
name: lint
strategy:
matrix:
go-version: [1.21.x]
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Install ONNX
run: |
wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.2/onnxruntime-linux-x64-1.16.2.tgz
tar xf onnxruntime-linux-x64-1.16.2.tgz
sudo cp ./onnxruntime-linux-x64-1.16.2/include/* /usr/local/include
sudo cp ./onnxruntime-linux-x64-1.16.2/lib/* /usr/local/lib
sudo ldconfig
- name: Checkout Code
uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.57.2
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,28 @@

- [Golang](https://go.dev/doc/install) >= v1.21
- A C compiler (e.g. GCC)
- ONNX Runtime
- A [Silero VAD](https://github.com/snakers4/silero-vad) model
- ONNX Runtime (v1.16.2)
- A [Silero VAD](https://github.com/snakers4/silero-vad) model (v5)

### Development

To build or run this library, export (or pass on the command line) the following environment variables so that the C compiler and linker can locate the ONNX Runtime headers and libraries.

#### Linux

```sh
LD_RUN_PATH="/usr/local/lib/onnxruntime-linux-x64-1.16.2/lib"
LIBRARY_PATH="/usr/local/lib/onnxruntime-linux-x64-1.16.2/lib"
C_INCLUDE_PATH="/usr/local/include/onnxruntime-linux-x64-1.16.2/include"
```

#### Darwin (macOS)

```sh
LIBRARY_PATH="/usr/local/lib/onnxruntime-linux-x64-1.16.2/lib"
C_INCLUDE_PATH="/usr/local/include/onnxruntime-linux-x64-1.16.2/include"
sudo update_dyld_shared_cache
```

### License

Expand Down
38 changes: 16 additions & 22 deletions speech/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@ import (
)

const (
hcLen = 2 * 1 * 64
stateLen = 2 * 1 * 128
)

type DetectorConfig struct {
// The path to the ONNX Silero VAD model file to load.
ModelPath string
// The sampling rate of the input audio samples. Supported values are 8000 and 16000.
SampleRate int
// The number of samples to process at each infer.
WindowSize int
// The probability threshold above which we detect speech. A good default is 0.5.
Threshold float32
// The duration of silence to wait for each speech segment before separating it.
Expand All @@ -39,11 +37,6 @@ func (c DetectorConfig) IsValid() error {
return fmt.Errorf("invalid SampleRate: valid values are 8000 and 16000")
}

if (c.SampleRate == 16000 && c.WindowSize != 512 && c.WindowSize != 1024 && c.WindowSize != 1536) ||
(c.SampleRate == 8000 && c.WindowSize != 256 && c.WindowSize != 512 && c.WindowSize != 768) {
return fmt.Errorf("invalid WindowSize: valid values are 512, 1024, 1536 for 16000 sample rate and 256, 512, 768 for 8000 sample rate")
}

if c.Threshold <= 0 || c.Threshold >= 1 {
return fmt.Errorf("invalid Threshold: should be in range (0, 1)")
}
Expand All @@ -69,8 +62,7 @@ type Detector struct {

cfg DetectorConfig

h [hcLen]float32
c [hcLen]float32
state [stateLen]float32

currSample int
triggered bool
Expand Down Expand Up @@ -138,11 +130,9 @@ func NewDetector(cfg DetectorConfig) (*Detector, error) {

sd.cStrings["input"] = C.CString("input")
sd.cStrings["sr"] = C.CString("sr")
sd.cStrings["h"] = C.CString("h")
sd.cStrings["c"] = C.CString("c")
sd.cStrings["state"] = C.CString("state")
sd.cStrings["stateN"] = C.CString("stateN")
sd.cStrings["output"] = C.CString("output")
sd.cStrings["hn"] = C.CString("hn")
sd.cStrings["cn"] = C.CString("cn")

return &sd, nil
}
Expand All @@ -160,7 +150,12 @@ func (sd *Detector) Detect(pcm []float32) ([]Segment, error) {
return nil, fmt.Errorf("invalid nil detector")
}

if len(pcm) < sd.cfg.WindowSize {
windowSize := 512
if sd.cfg.SampleRate == 8000 {
windowSize = 256
}

if len(pcm) < windowSize {
return nil, fmt.Errorf("not enough samples")
}

Expand All @@ -170,21 +165,21 @@ func (sd *Detector) Detect(pcm []float32) ([]Segment, error) {
speechPadSamples := sd.cfg.SpeechPadMs * sd.cfg.SampleRate / 1000

var segments []Segment
for i := 0; i < len(pcm)-sd.cfg.WindowSize; i += sd.cfg.WindowSize {
speechProb, err := sd.infer(pcm[i : i+sd.cfg.WindowSize])
for i := 0; i < len(pcm)-windowSize; i += windowSize {
speechProb, err := sd.infer(pcm[i : i+windowSize])
if err != nil {
return nil, fmt.Errorf("infer failed: %w", err)
}

sd.currSample += sd.cfg.WindowSize
sd.currSample += windowSize

if speechProb >= sd.cfg.Threshold && sd.tempEnd != 0 {
sd.tempEnd = 0
}

if speechProb >= sd.cfg.Threshold && !sd.triggered {
sd.triggered = true
speechStartAt := (float64(sd.currSample-sd.cfg.WindowSize-speechPadSamples) / float64(sd.cfg.SampleRate))
speechStartAt := (float64(sd.currSample-windowSize-speechPadSamples) / float64(sd.cfg.SampleRate))

// We clamp at zero since due to padding the starting position could be negative.
if speechStartAt < 0 {
Expand Down Expand Up @@ -233,9 +228,8 @@ func (sd *Detector) Reset() error {
sd.currSample = 0
sd.triggered = false
sd.tempEnd = 0
for i := 0; i < hcLen; i++ {
sd.h[i] = 0
sd.c[i] = 0
for i := 0; i < stateLen; i++ {
sd.state[i] = 0
}

return nil
Expand Down
47 changes: 12 additions & 35 deletions speech/detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,11 @@ func TestDetectorConfigIsValid(t *testing.T) {
},
err: "invalid SampleRate: valid values are 8000 and 16000",
},
{
name: "invalid WindowSize",
cfg: DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
},
err: "invalid WindowSize: valid values are 512, 1024, 1536 for 16000 sample rate and 256, 512, 768 for 8000 sample rate",
},
{
name: "invalid WindowSize for rate",
cfg: DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 768,
},
err: "invalid WindowSize: valid values are 512, 1024, 1536 for 16000 sample rate and 256, 512, 768 for 8000 sample rate",
},
{
name: "invalid Threshold",
cfg: DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 1536,
Threshold: 0,
},
err: "invalid Threshold: should be in range (0, 1)",
Expand All @@ -63,7 +45,6 @@ func TestDetectorConfigIsValid(t *testing.T) {
cfg: DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 1536,
Threshold: 0.5,
MinSilenceDurationMs: -1,
},
Expand All @@ -74,7 +55,6 @@ func TestDetectorConfigIsValid(t *testing.T) {
cfg: DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 1536,
Threshold: 0.5,
SpeechPadMs: -1,
},
Expand All @@ -85,7 +65,6 @@ func TestDetectorConfigIsValid(t *testing.T) {
cfg: DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 1536,
Threshold: 0.5,
},
},
Expand All @@ -107,7 +86,6 @@ func TestNewDetector(t *testing.T) {
cfg := DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 1536,
Threshold: 0.5,
}

Expand All @@ -123,7 +101,6 @@ func TestSpeechDetection(t *testing.T) {
cfg := DetectorConfig{
ModelPath: "../testfiles/silero_vad.onnx",
SampleRate: 16000,
WindowSize: 1536,
Threshold: 0.5,
}

Expand Down Expand Up @@ -154,15 +131,15 @@ func TestSpeechDetection(t *testing.T) {
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 1.056,
SpeechEndAt: 1.728,
SpeechStartAt: 1.088,
SpeechEndAt: 1.632,
},
{
SpeechStartAt: 2.88,
SpeechStartAt: 2.912,
SpeechEndAt: 3.264,
},
{
SpeechStartAt: 4.416,
SpeechStartAt: 4.448,
SpeechEndAt: 0,
},
}, segments)
Expand All @@ -177,15 +154,15 @@ func TestSpeechDetection(t *testing.T) {
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 1.056,
SpeechEndAt: 1.728,
SpeechStartAt: 1.088,
SpeechEndAt: 1.632,
},
{
SpeechStartAt: 2.88,
SpeechStartAt: 2.912,
SpeechEndAt: 3.264,
},
{
SpeechStartAt: 4.416,
SpeechStartAt: 4.448,
SpeechEndAt: 0,
},
}, segments)
Expand All @@ -205,15 +182,15 @@ func TestSpeechDetection(t *testing.T) {
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 1.056 - 0.01,
SpeechEndAt: 1.728 + 0.01,
SpeechStartAt: 1.088 - 0.01,
SpeechEndAt: 1.632 + 0.01,
},
{
SpeechStartAt: 2.88 - 0.01,
SpeechStartAt: 2.912 - 0.01,
SpeechEndAt: 3.264 + 0.01,
},
{
SpeechStartAt: 4.416 - 0.01,
SpeechStartAt: 4.448 - 0.01,
SpeechEndAt: 0,
},
}, segments)
Expand Down
Loading

0 comments on commit dbd4af5

Please sign in to comment.