forked from SunoAI-API/Suno-API
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
162 lines (131 loc) · 4.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# -*- coding:utf-8 -*-
import asyncio
import json
import time
import traceback

from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware

import schemas
from deps import get_token
from utils import generate_lyrics, generate_music, get_feed, get_lyrics, get_credits
from cookie import suno_auth, start_keep_alive
# The ASGI application object; every route below is registered on it.
app = FastAPI()

# CORS: accept any origin, method and header, and allow credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive policy -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
@app.get("/")
async def get_root():
    """Serve the root path with a default-constructed ``schemas.Response``."""
    default_response = schemas.Response()
    return default_response
@app.post("/generate")
async def generate(
    data: schemas.CustomModeGenerateParam, token: str = Depends(get_token)
):
    """Start a custom-mode generation and return the upstream response.

    The request model is converted to a plain dict and handed to
    ``generate_music`` together with the resolved token.

    Raises:
        HTTPException: 500 with the stringified error if anything fails.
    """
    payload = data.dict()
    try:
        return await generate_music(payload, token)
    except Exception as err:
        # Surface any failure to the client as a 500 carrying the message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(err)
        )
@app.post("/generate/description-mode")
async def generate_with_song_description(
    data: schemas.DescriptionModeGenerateParam, token: str = Depends(get_token)
):
    """Generate music from a song description, rotating accounts on failure.

    At most one attempt is made per account that is active when the request
    arrives. Each attempt first checks the remaining credits of the current
    token, then submits the generation. The account is rotated (and the token
    refreshed from ``suno_auth``) when credits are zero, when the upstream
    reports "Too many running jobs.", or when an exception is raised.

    Args:
        data: Description-mode generation parameters.
        token: Token for the current account, injected by ``get_token``.

    Returns:
        Whatever ``generate_music`` returns for the first successful attempt.

    Raises:
        HTTPException: 500 once every attempt has been used up.
    """
    max_retries = len(suno_auth.account_manager.active_accounts)
    retry_count = 0

    while retry_count < max_retries:
        try:
            # Check the account's credits before spending a request on it.
            credits_info = await get_credits(token)
            if credits_info["credits_left"] == 0:
                # Out of credits: per the original author's note this call
                # disables the current account and moves on to the next one.
                suno_auth.handle_insufficient_credits()
                # Non-blocking pause; time.sleep() here would stall the event
                # loop and freeze every other in-flight request.
                await asyncio.sleep(1)
                token = suno_auth.get_token()
                retry_count += 1
                continue

            # Credits available: submit the generation.
            resp = await generate_music(data.dict(), token)

            # Upstream signals a busy account through a "detail" message.
            if isinstance(resp, dict) and resp.get("detail") == "Too many running jobs.":
                # Busy account: only switch accounts, do not disable it.
                suno_auth.load_next_account()
                await asyncio.sleep(1)
                token = suno_auth.get_token()
                retry_count += 1
                continue

            return resp
        except Exception as exc:
            traceback.print_exc()
            retry_count += 1
            if retry_count >= max_retries:
                raise HTTPException(
                    detail="All accounts exhausted or error occurred",
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                ) from exc
            # NOTE(review): any error is handled like an out-of-credits
            # account -- confirm that this is the intended policy.
            suno_auth.handle_insufficient_credits()
            await asyncio.sleep(1)
            token = suno_auth.get_token()

    # Reached when no accounts were active or every attempt was rotated away.
    raise HTTPException(
        detail="All accounts exhausted",
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
@app.get("/feed/{aid}")
async def fetch_feed(aid: str, token: str = Depends(get_token)):
    """Return the result of ``get_feed`` for the id given in the path.

    Raises:
        HTTPException: 500 with the stringified error if the lookup fails.
    """
    try:
        return await get_feed(aid, token)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(err)
        )
@app.post("/generate/lyrics/")
async def generate_lyrics_post(request: Request, token: str = Depends(get_token)):
    """Generate lyrics for the ``prompt`` field of a JSON request body.

    Args:
        request: Incoming request; its body must be a JSON object that
            contains a ``prompt`` key.
        token: Token injected by ``get_token``.

    Returns:
        Whatever ``generate_lyrics`` returns for the prompt.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or has no ``prompt``; 500 if lyric generation fails.
    """
    try:
        req = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
        # a malformed body is a client error, not a server fault.
        raise HTTPException(
            detail="request body must be valid JSON",
            status_code=status.HTTP_400_BAD_REQUEST,
        ) from exc

    # A JSON array or scalar has no .get(); reject it before reading the key.
    if not isinstance(req, dict):
        raise HTTPException(
            detail="request body must be a JSON object",
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    prompt = req.get("prompt")
    if prompt is None:
        raise HTTPException(
            detail="prompt is required", status_code=status.HTTP_400_BAD_REQUEST
        )

    try:
        resp = await generate_lyrics(prompt, token)
        return resp
    except Exception as e:
        raise HTTPException(
            detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
@app.get("/lyrics/{lid}")
async def fetch_lyrics(lid: str, token: str = Depends(get_token)):
    """Return the result of ``get_lyrics`` for the id given in the path.

    Raises:
        HTTPException: 500 with the stringified error if the lookup fails.
    """
    try:
        return await get_lyrics(lid, token)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(err)
        )
@app.get("/get_credits")
async def fetch_credits(token: str = Depends(get_token)):
    """Return the result of ``get_credits`` for the resolved token.

    Raises:
        HTTPException: 500 with the stringified error if the lookup fails;
            the traceback is printed before the error is raised.
    """
    try:
        return await get_credits(token)
    except Exception as err:
        # Keep the full traceback in the server output for diagnosis.
        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(err)
        )
@app.on_event("startup")
async def startup_event():
    """Reload account state, print a summary, then start the keep-alive."""
    print("Starting application...")

    # Reload the account state.
    manager = suno_auth.account_manager
    manager.load_accounts()
    manager.load_disabled_accounts()
    manager.update_active_accounts()

    total_count = len(manager.accounts)
    disabled_count = len(manager.disabled_accounts)
    active_count = len(manager.active_accounts)
    print(f"Loaded {total_count} total accounts")
    print(f"Found {disabled_count} disabled accounts")
    print(f"Active accounts: {active_count}")

    # Start keep_alive.
    start_keep_alive(suno_auth)