Skip to content

Commit f991615

Browse files
NULL204 and Tohrusky authored
refactor: class-based design (#15)
Co-authored-by: Tohrusky <[email protected]>
1 parent b67c9c6 commit f991615

12 files changed

+196
-79
lines changed

.github/workflows/CI-test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
CI:
2323
strategy:
2424
matrix:
25-
os-version: ["ubuntu-20.04", "macos-13", "windows-latest"]
25+
os-version: ["ubuntu-20.04", "windows-latest", "macos-13"]
2626
python-version: ["3.9"]
2727
poetry-version: ["1.8.3"]
2828

@@ -48,7 +48,7 @@ jobs:
4848
- name: Test
4949
run: |
5050
pip install numpy==1.26.4
51-
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2
51+
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper httpx tenacity pysubs2
5252
5353
make lint
5454
make test

README.md

+15-26
Original file line numberDiff line numberDiff line change
@@ -40,36 +40,25 @@ yuisub -h # Displays help message
4040
```python3
4141
import asyncio
4242

43-
from yuisub import translate, bilingual, load
44-
from yuisub.a2t import WhisperModel
43+
from yuisub import SubtitleTranslator
4544

46-
# use an asynchronous environment
45+
# Using an asynchronous environment
4746
async def main() -> None:
48-
49-
# sub from audio
50-
model = WhisperModel(name="medium", device="cuda")
51-
sub = model.transcribe(audio="path/to/audio.mp3")
52-
53-
# sub from file
54-
# sub = load("path/to/input.srt")
55-
56-
# generate bilingual subtitle
57-
sub_zh = await translate(
58-
sub=sub,
59-
model="gpt_model_name",
60-
api_key="your_openai_api_key",
61-
base_url="api_url",
62-
bangumi_url="https://bangumi.tv/subject/424883/"
63-
)
64-
65-
sub_bilingual = await bilingual(
66-
sub_origin=sub,
67-
sub_zh=sub_zh
47+
translator = SubtitleTranslator(
48+
# if you wanna use audio input
49+
# torch_device='cuda',
50+
# whisper_model='medium',
51+
52+
model='gpt_model_name',
53+
api_key='your_openai_api_key',
54+
base_url='api_url',
55+
bangumi_url='https://bangumi.tv/subject/424883/',
56+
bangumi_access_token='your_bangumi_token',
6857
)
6958

70-
# save the ASS files
71-
sub_zh.save("path/to/output.zh.ass")
72-
sub_bilingual.save("path/to/output.bilingual.ass")
59+
sub_zh, sub_bilingual = await translator.get_subtitles(sub='path/to/sub.srt') # Or audio='path/to/audio.mp3',
60+
sub_zh.save('path/to/output_zh.ass')
61+
sub_bilingual.save('path/to/output_bilingual.ass')
7362

7463
asyncio.run(main())
7564
```

tests/test_bangumi.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from yuisub import bangumi
22

3+
from . import util
4+
35

46
async def test_bangumi() -> None:
57
url_list = [
@@ -9,6 +11,6 @@ async def test_bangumi() -> None:
911
]
1012

1113
for url in url_list:
12-
r = await bangumi(url)
14+
r = await bangumi(url=url, token=util.BANGUMI_ACCESS_TOKEN)
1315
print(r.introduction)
1416
print(r.characters)

tests/test_llm.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44

5-
from tests import util
65
from yuisub import ORIGIN, Summarizer, Translator, bangumi
76

7+
from . import util
8+
89
origin = ORIGIN(
910
origin="何だよ…けっこう多いじゃねぇか",
1011
)
@@ -65,7 +66,7 @@ async def test_llm_bangumi() -> None:
6566
model=util.OPENAI_MODEL,
6667
api_key=util.OPENAI_API_KEY,
6768
base_url=util.OPENAI_BASE_URL,
68-
bangumi_info=await bangumi(util.BANGUMI_URL),
69+
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
6970
)
7071
print(t.system_prompt)
7172
res = await t.ask(origin)
@@ -78,7 +79,7 @@ async def test_llm_bangumi_2() -> None:
7879
model=util.OPENAI_MODEL,
7980
api_key=util.OPENAI_API_KEY,
8081
base_url=util.OPENAI_BASE_URL,
81-
bangumi_info=await bangumi(util.BANGUMI_URL),
82+
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
8283
)
8384
print(t.system_prompt)
8485
s = ORIGIN(
@@ -95,7 +96,7 @@ async def test_llm_summary() -> None:
9596
model=util.OPENAI_MODEL,
9697
api_key=util.OPENAI_API_KEY,
9798
base_url=util.OPENAI_BASE_URL,
98-
bangumi_info=await bangumi(util.BANGUMI_URL),
99+
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
99100
)
100101
print(t.system_prompt)
101102
res = await t.ask(summary_origin)

tests/test_sub.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import pytest
44

5-
from tests import util
65
from yuisub.a2t import WhisperModel
76
from yuisub.sub import bilingual, load, translate
87

8+
from . import util
9+
910

1011
def test_sub() -> None:
1112
sub = load(util.TEST_ENG_SRT)
@@ -34,6 +35,7 @@ async def test_bilingual_2() -> None:
3435
api_key=util.OPENAI_API_KEY,
3536
base_url=util.OPENAI_BASE_URL,
3637
bangumi_url=util.BANGUMI_URL,
38+
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
3739
)
3840
sub_bilingual = await bilingual(sub_origin=sub, sub_zh=sub_zh)
3941

tests/test_translator.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
3+
import pytest
4+
5+
from yuisub.translator import SubtitleTranslator
6+
7+
from . import util
8+
9+
10+
@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
11+
async def test_translator_sub() -> None:
12+
translator = SubtitleTranslator(
13+
model=util.OPENAI_MODEL,
14+
api_key=util.OPENAI_API_KEY,
15+
base_url=util.OPENAI_BASE_URL,
16+
bangumi_url=util.BANGUMI_URL,
17+
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
18+
)
19+
20+
sub_zh, sub_bilingual = await translator.get_subtitles(sub=str(util.TEST_ENG_SRT))
21+
sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.sub.ass")
22+
sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.sub.ass")
23+
24+
25+
@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
26+
async def test_translator_audio() -> None:
27+
translator = SubtitleTranslator(
28+
torch_device=util.DEVICE,
29+
whisper_model=util.MODEL_NAME,
30+
model=util.OPENAI_MODEL,
31+
api_key=util.OPENAI_API_KEY,
32+
base_url=util.OPENAI_BASE_URL,
33+
bangumi_url=util.BANGUMI_URL,
34+
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
35+
)
36+
37+
sub_zh, sub_bilingual = await translator.get_subtitles(audio=str(util.TEST_AUDIO))
38+
sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.audio.ass")
39+
sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.audio.ass")

tests/util.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
import os
22
from pathlib import Path
33

4-
import torch
5-
64
projectPATH = Path(__file__).resolve().parent.parent.absolute()
75

86
TEST_AUDIO = projectPATH / "assets" / "test.mp3"
97
TEST_ENG_SRT = projectPATH / "assets" / "eng.srt"
108

11-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9+
DEVICE = "cpu" if os.environ.get("GITHUB_ACTIONS") == "true" else None
1210
MODEL_NAME = "medium" if DEVICE == "cuda" else "tiny"
1311

1412
BANGUMI_URL = "https://bangumi.tv/subject/424883"
13+
BANGUMI_ACCESS_TOKEN = ""
1514

1615
OPENAI_MODEL = str(os.getenv("OPENAI_MODEL")) if os.getenv("OPENAI_MODEL") else "deepseek-chat"
1716
OPENAI_BASE_URL = str(os.getenv("OPENAI_BASE_URL")) if os.getenv("OPENAI_BASE_URL") else "https://api.deepseek.com"

yuisub/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from yuisub.llm import Summarizer, Translator # noqa: F401
33
from yuisub.prompt import ORIGIN, ZH # noqa: F401
44
from yuisub.sub import advertisement, bilingual, load, translate # noqa: F401
5+
from yuisub.translator import SubtitleTranslator # noqa: F401

yuisub/__main__.py

+17-40
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
11
import argparse
22
import asyncio
3-
import sys
43

5-
from yuisub.sub import bilingual, load, translate
4+
from yuisub import SubtitleTranslator
65

7-
# ffmpeg -i test.mkv -c:a mp3 -map 0:a:0 test.mp3
8-
# ffmpeg -i test.mkv -map 0:s:0 eng.srt
6+
parser = argparse.ArgumentParser(description="Generate Bilingual Subtitle from audio or subtitle file")
97

10-
parser = argparse.ArgumentParser()
11-
parser.description = "Generate Bilingual Subtitle from audio or subtitle file"
12-
# input
8+
# Input
139
parser.add_argument("-a", "--AUDIO", type=str, help="Path to the audio file", required=False)
1410
parser.add_argument("-s", "--SUB", type=str, help="Path to the input Subtitle file", required=False)
15-
# subtitle output
11+
# Output
1612
parser.add_argument("-oz", "--OUTPUT_ZH", type=str, help="Path to save the Chinese ASS file", required=False)
1713
parser.add_argument("-ob", "--OUTPUT_BILINGUAL", type=str, help="Path to save the bilingual ASS file", required=False)
18-
# openai gpt
14+
# OpenAI GPT
1915
parser.add_argument("-om", "--OPENAI_MODEL", type=str, help="Openai model name", required=True)
2016
parser.add_argument("-api", "--OPENAI_API_KEY", type=str, help="Openai API key", required=True)
2117
parser.add_argument("-url", "--OPENAI_BASE_URL", type=str, help="Openai base URL", required=True)
22-
# bangumi
18+
# Bangumi
2319
parser.add_argument("-bgm", "--BANGUMI_URL", type=str, help="Anime Bangumi URL", required=False)
2420
parser.add_argument("-ac", "--BANGUMI_ACCESS_TOKEN", type=str, help="Anime Bangumi Access Token", required=False)
25-
# whisper
21+
# Whisper
2622
parser.add_argument("-d", "--TORCH_DEVICE", type=str, help="Pytorch device to use", required=False)
2723
parser.add_argument("-wm", "--WHISPER_MODEL", type=str, help="Whisper model to use", required=False)
2824

@@ -33,47 +29,28 @@ async def main() -> None:
3329
if args.AUDIO and args.SUB:
3430
raise ValueError("Please provide only one input file, either audio or subtitle file")
3531

32+
if not args.AUDIO and not args.SUB:
33+
raise ValueError("Please provide an input file, either audio or subtitle file")
34+
3635
if not args.OUTPUT_ZH and not args.OUTPUT_BILINGUAL:
3736
raise ValueError("Please provide output paths for the subtitles.")
3837

39-
if args.AUDIO:
40-
import torch
41-
42-
from yuisub.a2t import WhisperModel
43-
44-
if args.TORCH_DEVICE:
45-
_DEVICE = args.TORCH_DEVICE
46-
else:
47-
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
48-
if sys.platform == "darwin":
49-
_DEVICE = "mps"
50-
51-
if args.WHISPER_MODEL:
52-
_MODEL = args.WHISPER_MODEL
53-
else:
54-
_MODEL = "medium" if _DEVICE == "cpu" else "large-v2"
55-
56-
model = WhisperModel(name=_MODEL, device=_DEVICE)
57-
58-
sub = model.transcribe(audio=args.AUDIO)
59-
60-
else:
61-
sub = load(args.SUB)
62-
63-
sub_zh = await translate(
64-
sub=sub,
38+
translator = SubtitleTranslator(
6539
model=args.OPENAI_MODEL,
6640
api_key=args.OPENAI_API_KEY,
6741
base_url=args.OPENAI_BASE_URL,
6842
bangumi_url=args.BANGUMI_URL,
6943
bangumi_access_token=args.BANGUMI_ACCESS_TOKEN,
44+
torch_device=args.TORCH_DEVICE,
45+
whisper_model=args.WHISPER_MODEL,
7046
)
7147

72-
sub_bilingual = await bilingual(sub_origin=sub, sub_zh=sub_zh)
73-
48+
sub_zh, sub_bilingual = await translator.get_subtitles(
49+
sub=args.SUB,
50+
audio=args.AUDIO,
51+
)
7452
if args.OUTPUT_ZH:
7553
sub_zh.save(args.OUTPUT_ZH)
76-
7754
if args.OUTPUT_BILINGUAL:
7855
sub_bilingual.save(args.OUTPUT_BILINGUAL)
7956

yuisub/a2t.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99

1010
class WhisperModel:
1111
def __init__(
12-
self, name: str = "medium", device: str = "cuda", download_root: Optional[str] = None, in_memory: bool = False
12+
self,
13+
name: str = "medium",
14+
device: Optional[Union[str, torch.device]] = None,
15+
download_root: Optional[str] = None,
16+
in_memory: bool = False,
1317
):
1418
self.model = whisper.load_model(name=name, device=device, download_root=download_root, in_memory=in_memory)
1519

yuisub/sub.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ async def translate(
112112
base_url=base_url,
113113
bangumi_info=bangumi_info,
114114
)
115+
print(summarizer.system_prompt)
115116

116-
print("Summarizing...")
117117
# get summary
118118
summary = await summarizer.ask(ORIGIN(origin="\n".join(trans_list)))
119119

0 commit comments

Comments (0)