forked from haoheliu/AudioLDM
-
Notifications
You must be signed in to change notification settings - Fork 6
/
plugin.py
164 lines (157 loc) · 5.8 KB
/
plugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from tuneflow_py import TuneflowPlugin, ParamDescriptor, Song, WidgetType, InjectSource, TuneflowPluginTriggerData
from typing import Dict, Any
from audioldm import text_to_audio, build_model
from pathlib import Path
import traceback
from io import BytesIO
import soundfile as sf
from typing import List
# Absolute path to ckpt/ldm_trimmed.ckpt, resolved relative to this file's
# directory so that loading does not depend on the current working directory.
model_path = str((Path(__file__).parent / 'ckpt/ldm_trimmed.ckpt').absolute())
# Build the model once at import time; AudioLDMPlugin.run passes this instance
# to every generation request.
model = build_model(ckpt_path=model_path)
class AudioLDMPlugin(TuneflowPlugin):
    """TuneFlow plugin that turns a text prompt into an audio clip via AudioLDM.

    The plugin reads the prompt and sampling parameters supplied by the user,
    calls ``audioldm.text_to_audio`` with the module-level ``model``, encodes
    the returned waveform(s) as in-memory WAV files and inserts the first one
    that can be inserted as an audio clip on the triggering track.
    """

    # Sample rate handed to soundfile when encoding the generated waveform.
    # NOTE(review): presumably the output rate of the AudioLDM checkpoint --
    # confirm before changing.
    _SAMPLE_RATE = 16000

    @staticmethod
    def provider_id() -> str:
        """Return the identifier of the plugin provider."""
        return 'andantei'

    @staticmethod
    def plugin_id() -> str:
        """Return the identifier of this plugin."""
        return 'audioldm-generate'

    @staticmethod
    def params(song: Song) -> Dict[str, ParamDescriptor]:
        """Describe the parameters collected before ``run`` is called.

        Args:
            song: The current song. Not used when building the descriptors.

        Returns:
            A mapping from parameter name to its descriptor: ``prompt``,
            ``guidance_scale``, ``duration``, ``random_seed`` and the hidden
            ``playhead_tick``, which is injected from the playhead position
            snapped to the beat.
        """
        return {
            "prompt": {
                "displayName": {
                    "en": "Prompt",
                    "zh": "提示词"
                },
                "description": {
                    "en": "A short sentence to describe the audio you want to generate",
                    "zh": "用一段简短的文字描述你想要的音频"
                },
                "defaultValue": None,
                "widget": {
                    "type": WidgetType.TextArea.value,
                    "config": {
                        "placeholder": {
                            "zh": "样例:斧头正在伐木",
                            "en": "e.g. A hammer is hitting a tree"
                        },
                        "maxLength": 140
                    }
                }
            },
            "guidance_scale": {
                "displayName": {
                    "en": "Guidance Scale",
                    "zh": "提示强度"
                },
                "description": {
                    "en": "Larger value yields results more relevant to the prompt, smaller value yields more diversity",
                    "zh": "值越大,生成结果越贴近提示词,值越小,生成结果越发散"
                },
                "defaultValue": 2.5,
                "widget": {
                    "type": WidgetType.InputNumber.value,
                    "config": {
                        "minValue": 0.1,
                        "maxValue": 10,
                        "step": 0.1
                    }
                }
            },
            "duration": {
                "displayName": {
                    "en": "Duration (seconds)",
                    "zh": "长度 (秒)"
                },
                "defaultValue": 10,
                "widget": {
                    "type": WidgetType.InputNumber.value,
                    "config": {
                        "minValue": 2.5,
                        "maxValue": 10,
                        "step": 2.5
                    }
                }
            },
            "random_seed": {
                "displayName": {
                    "en": "Random Seed",
                    "zh": "随机因子"
                },
                "defaultValue": 42,
                "description": {
                    "en": "Using the same params and random seed generates the same response",
                    "zh": "使用相同的参数值和随机因子可以生成相同的结果"
                },
                "widget": {
                    "type": WidgetType.InputNumber.value,
                    "config": {
                        "minValue": 1,
                        "maxValue": 99999999,
                        "step": 1
                    }
                }
            },
            "playhead_tick": {
                "displayName": {
                    "zh": '当前指针位置',
                    "en": 'Playhead Position',
                },
                "defaultValue": None,
                "widget": {
                    "type": WidgetType.InputNumber.value,
                },
                # Not shown to the user; filled in from the playhead position.
                "hidden": True,
                "injectFrom": InjectSource.TickAtPlayheadSnappedToBeat.value,
            },
        }

    @staticmethod
    def run(song: Song, params: Dict[str, Any]):
        """Generate audio for the prompt and add it to the triggering track.

        Args:
            song: The song that owns the target track.
            params: Parameter values keyed as in ``params()``, plus a
                ``"trigger"`` entry whose first entity carries the id of the
                track to write to.

        Raises:
            Exception: If the track cannot be found, if no prompt was given,
                or if none of the generated candidates could be inserted as
                an audio clip.
        """
        trigger: TuneflowPluginTriggerData = params["trigger"]
        track_id = trigger["entities"][0]["trackId"]
        track = song.get_track_by_id(track_id=track_id)
        if track is None:
            raise Exception('track not found')

        # The prompt's default value is None; reject it here with a clear
        # message instead of handing it to the model.
        prompt = params["prompt"]
        if prompt is None:
            raise Exception('prompt is required')

        # TODO: Support prompt i18n
        file_bytes_list = AudioLDMPlugin._text2audio(
            model,
            text=prompt,
            duration=params["duration"],
            guidance_scale=params["guidance_scale"],
            # The seed comes from the user so results are reproducible.
            random_seed=params["random_seed"])

        # Try the candidates in order and stop at the first one that can be
        # inserted; a failing candidate is logged and the next one is tried.
        last_error = None
        for file_bytes in file_bytes_list:
            try:
                file_bytes.seek(0)
                track.create_audio_clip(clip_start_tick=0, audio_clip_data={
                    "audio_data": {
                        "format": "wav",
                        "data": file_bytes.read()
                    },
                    "duration": params["duration"],
                    "start_tick": params["playhead_tick"]
                })
            except Exception as error:
                # Exception (not a bare except) so that KeyboardInterrupt and
                # SystemExit still propagate.
                print(traceback.format_exc())
                last_error = error
            else:
                return

        # Nothing was inserted: surface the failure instead of returning as
        # if the plugin had succeeded.
        raise Exception(
            'failed to create an audio clip from the generated audio'
        ) from last_error

    @staticmethod
    def _text2audio(model, text, duration, guidance_scale, random_seed):
        """Run AudioLDM on ``text`` and return the result as WAV buffers.

        Args:
            model: The model object passed through to ``text_to_audio``.
            text: The prompt describing the audio to generate.
            duration: Requested length, forwarded as ``duration``.
            guidance_scale: Forwarded as ``guidance_scale``.
            random_seed: Forwarded as ``seed``.

        Returns:
            The list of ``BytesIO`` buffers produced by ``_save_wave``.
        """
        waveform = text_to_audio(
            model,
            text=text,
            seed=random_seed,
            duration=duration,
            guidance_scale=guidance_scale,
            n_candidate_gen_per_text=3,
            batchsize=1,
        )
        return AudioLDMPlugin._save_wave(waveform)

    @staticmethod
    def _save_wave(waveform):
        """Encode each entry of ``waveform`` as an in-memory WAV file.

        Args:
            waveform: Array-like exposing ``shape`` and indexable as
                ``waveform[i, 0]``.
                NOTE(review): presumably (batch, channel, samples) -- confirm
                against ``audioldm.text_to_audio``.

        Returns:
            One ``BytesIO`` per entry along the first axis, each rewound to
            the start so it can be read immediately.
        """
        saved_file_bytes: List[BytesIO] = []
        for index in range(waveform.shape[0]):
            file_bytes = BytesIO()
            sf.write(
                file_bytes,
                waveform[index, 0],
                samplerate=AudioLDMPlugin._SAMPLE_RATE,
                format="wav",
            )
            # sf.write leaves the position at the end of the buffer.
            file_bytes.seek(0)
            saved_file_bytes.append(file_bytes)
        return saved_file_bytes