
Commit 8f6d947

release
committed
1 parent f6515b4 commit 8f6d947

File tree

5 files changed: +567 -2 lines changed

README.md

+104-2
@@ -1,5 +1,107 @@
# CSM

-A Conversational Speech Generation Model
+**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on HuggingFace](https://huggingface.co/sesame/csm-1b).

-- Landing soon. Our open source model that powers our conversational speech generation [demo](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice#demo)
+---

CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.

A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).

A hosted [HuggingFace space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation.
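
To give a rough picture of how text and audio share one token stream, the generator packs each step into a 33-wide frame: 32 Mimi codebook slots plus one text-token slot (see `generator.py` below). The snippet below is only an illustrative sketch of that layout, not part of the package API, and the token ids are made up.

```python
import torch

# Illustrative sketch of the frame layout used in generator.py (not a public API).
# Each frame row has 33 slots: 32 Mimi RVQ codebook tokens + 1 text-token slot.
NUM_CODEBOOKS = 32
FRAME_WIDTH = NUM_CODEBOOKS + 1

# A text frame carries its Llama token id in the last column; the audio columns stay zero.
text_frame = torch.zeros(1, FRAME_WIDTH, dtype=torch.long)
text_frame[0, -1] = 123  # hypothetical text token id

# An audio frame carries one Mimi code per codebook in the first 32 columns.
audio_frame = torch.zeros(1, FRAME_WIDTH, dtype=torch.long)
audio_frame[0, :-1] = torch.randint(0, 2051, (NUM_CODEBOOKS,))  # 2051 = audio vocab size

print(text_frame.shape, audio_frame.shape)  # both torch.Size([1, 33])
```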

## Usage

Set up the repo

```bash
git clone git@github.com:SesameAILabs/csm.git
cd csm
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
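
If the CSM checkpoint or the `meta-llama/Llama-3.2-1B` tokenizer is gated for your Hugging Face account, you may need to authenticate before the downloads in the examples below will work. A minimal sketch of this optional step:

```python
# Optional: authenticate with Hugging Face if the gated downloads fail with an authorization error.
from huggingface_hub import login

login()  # prompts for an access token from https://huggingface.co/settings/tokens
```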

Generate a sentence

```python
from huggingface_hub import hf_hub_download
from generator import load_csm_1b
import torchaudio

model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
generator = load_csm_1b(model_path, "cuda")
audio = generator.generate(
    text="Hello from Sesame.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```

CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker utterance.

```python
from generator import Segment

speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]

def load_audio(audio_path):
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```

## FAQ

**Does this model come with any voices?**

The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice.

**Can I converse with the model?**

CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation.
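
As a rough illustration of that split, you can pair any text LLM with CSM: the LLM writes the reply and CSM speaks it. The sketch below assumes the `generator` and `torchaudio` objects from the Usage section above; the LLM, prompt, and file name are placeholders, not recommendations.

```python
# Illustrative sketch: a separate text LLM writes a reply, then CSM speaks it.
from transformers import pipeline

llm = pipeline("text-generation", model="meta-llama/Llama-3.2-1B-Instruct")
reply = llm(
    "Reply in one short sentence: how was your weekend?",
    max_new_tokens=40,
    return_full_text=False,
)[0]["generated_text"]

# `generator` is the CSM generator created in the Usage section above.
audio = generator.generate(text=reply, speaker=0, context=[], max_audio_length_ms=10_000)
torchaudio.save("reply.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```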

**Does it support other languages?**

The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.

## Misuse and abuse ⚠️

This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:

- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.

By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.

**Authors**

Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.

generator.py

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
```python
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model, ModelArgs
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark


@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(
        self,
        model: Model,
    ):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self._watermarker = load_watermarker(device=device)

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        self._model.reset_caches()

        max_audio_frames = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048 - max_audio_frames
        if curr_tokens.size(1) >= max_seq_len:
            raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")

        for _ in range(max_audio_frames):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

        return audio


def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
    model_args = ModelArgs(
        backbone_flavor="llama-1B",
        decoder_flavor="llama-100M",
        text_vocab_size=128256,
        audio_vocab_size=2051,
        audio_num_codebooks=32,
    )
    model = Model(model_args).to(device=device, dtype=torch.bfloat16)
    state_dict = torch.load(ckpt_path)
    model.load_state_dict(state_dict)

    generator = Generator(model)
    return generator
```
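
For a sense of how `max_audio_length_ms` interacts with the context budget in `generate` above: each Mimi frame covers 80 ms, and the prompt plus the frames to be generated must fit in a 2048-position window. A quick worked example (the 10-second value is just an illustration):

```python
# Worked example of the frame budget used by Generator.generate above.
max_audio_length_ms = 10_000
max_audio_frames = int(max_audio_length_ms / 80)  # 125 frames of 80 ms each
max_seq_len = 2048 - max_audio_frames             # 1923 positions left for context + text prompt
print(max_audio_frames, max_seq_len)
```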
