from dataclasses import dataclass
from typing import List, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model, ModelArgs
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark


@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor
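
# A minimal sketch (not part of the original file) of building a context Segment
# from a wav file. The path, transcript, and speaker id are placeholders; the audio
# must be a mono (num_samples,) tensor at the generator's 24 kHz sample rate.
#
#   wav, sr = torchaudio.load("speaker_0_prompt.wav")
#   wav = torchaudio.functional.resample(wav.squeeze(0), orig_freq=sr, new_freq=24_000)
#   segment = Segment(speaker=0, text="Transcript of the prompt audio.", audio=wav)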


def load_llama3_tokenizer():
    """
    Load the Llama 3.2 text tokenizer and attach a post-processor so that every
    encoded sequence is wrapped in BOS/EOS tokens.
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer

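# With the post-processor above, encode() brackets each sequence with the special
# tokens (illustrative sketch; exact ids depend on the downloaded tokenizer files):
#
#   tokenizer = load_llama3_tokenizer()
#   tokenizer.encode("hello")  # -> [tokenizer.bos_token_id, ..., tokenizer.eos_token_id]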


class Generator:
    def __init__(
        self,
        model: Model,
    ):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        # Mimi is the neural audio codec used to tokenize and decode audio.
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self._watermarker = load_watermarker(device=device)

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        # Each frame has 33 columns: 32 audio codebooks followed by one text column.
        # Text tokens go in the last column; the audio columns stay zero with mask False.
        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        # (K, T): K codebooks, T codec frames
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        # Audio tokens fill the first 32 columns; the text column stays zero with mask False.
        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        self._model.reset_caches()

        # Each generated frame corresponds to 80 ms of audio.
        max_audio_frames = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048 - max_audio_frames
        if curr_tokens.size(1) >= max_seq_len:
            raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")

        for _ in range(max_audio_frames):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

            # Feed the sampled audio frame back in; the text column is zero with mask False.
            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

        return audio


def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
    model_args = ModelArgs(
        backbone_flavor="llama-1B",
        decoder_flavor="llama-100M",
        text_vocab_size=128256,
        audio_vocab_size=2051,
        audio_num_codebooks=32,
    )
    model = Model(model_args).to(device=device, dtype=torch.bfloat16)
    state_dict = torch.load(ckpt_path)
    model.load_state_dict(state_dict)

    generator = Generator(model)
    return generator
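

# A minimal usage sketch, assuming a local checkpoint at the default "ckpt.pt" path
# and an available CUDA device; the text, speaker id, and output filename below are
# placeholders.
if __name__ == "__main__":
    generator = load_csm_1b(ckpt_path="ckpt.pt", device="cuda")
    audio = generator.generate(
        text="Hello from CSM.",
        speaker=0,
        context=[],
        max_audio_length_ms=10_000,
    )
    # generate() returns a mono waveform at generator.sample_rate.
    torchaudio.save("output.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)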