408 lines
13 KiB
Python
408 lines
13 KiB
Python
import threading, queue, time, os, json
|
|
import torch, torchaudio
|
|
import numpy as np
|
|
import sounddevice as sd
|
|
from einops import rearrange
|
|
from stable_audio_tools import get_pretrained_model
|
|
from stable_audio_tools.inference.generation import generate_diffusion_cond
|
|
from pydub import AudioSegment
|
|
import webview
|
|
|
|
# Directory where finished clips are written when recording is enabled.
SAVE_DIR = "clips"; os.makedirs(SAVE_DIR, exist_ok=True)
# Crossfade overlap between consecutive generated clips, in milliseconds.
FADE_MS = 3000
# Fallback prompt used whenever the user-supplied prompt is empty/blank.
DEFAULT_PROMPT = "cozy medieval tavern ambience with flute, harp and soft percussion, hearthstone, folk, peaceful forest cabin melody with lute and gentle wind chimes, warm and safe, serene elven inn with light lyre and soft strings, magical and relaxing, evening at a rustic inn, lute and low flute with crackling fire village square at dusk, medieval folk instruments and background chatter, ocarina of time, folk, pagan, percussion, congas, drums"

# — Shared state & buffer —
# Mutable dict shared (without locks) between the GUI, playback and
# generator threads; each thread only does single reads/writes of slots.
state = {
    "playing": False,          # playback toggle, flipped by API.playpause()
    "prompt": DEFAULT_PROMPT,  # current text prompt for generation
    "spinner": False,          # show the busy spinner in the UI
    "volume": 80,              # 0-100 slider value (80 is unity gain in playback_worker)
    "need_ui_update": False,   # force a UI refresh even if nothing visibly changed
    "resume_np": None,         # samples of a clip paused mid-playback
    "resume_sr": None,         # sample rate recorded for the paused clip
    "resume_idx": 0,           # frame index to resume playback from
    "model_ready": False,      # set once load_model() finishes
    "record": True,            # save generated clips to disk when True
}
# Bounded hand-off of (samples, sample_rate) chunks: generator -> player.
AUDIO_QUEUE = queue.Queue(maxsize=3)
# Populated by load_model() on a background thread.
model = None
config = None
|
|
|
|
# — Model loading in the background —
|
|
def load_model():
    """Load the pretrained diffusion model on a background thread.

    Publishes the model/config through the module globals and flips the
    shared-state flags so the UI can dismiss its loading overlay.
    """
    global model, config
    target = "mps" if torch.backends.mps.is_available() else "cpu"
    loaded_model, loaded_config = get_pretrained_model("stabilityai/stable-audio-open-small")
    loaded_model = loaded_model.to(target)
    loaded_model.half()  # fp16 to cut memory/compute on the target device
    model = loaded_model
    config = loaded_config
    state["model_ready"] = True
    state["need_ui_update"] = True
|
|
|
|
# — Audio generation helper —
|
|
def generate_and_save(prompt, duration=11):
    """Generate one clip with the diffusion model and return it as a pydub AudioSegment.

    Blocks until the model has finished loading. When state["record"] is on,
    the clip is also written to SAVE_DIR as a WAV file; otherwise it is
    round-tripped through an in-memory buffer.

    Parameters:
        prompt:   text conditioning; falls back to DEFAULT_PROMPT when blank.
        duration: requested clip length in seconds.
    """
    import io

    while not state["model_ready"]:
        time.sleep(0.1)
    prompt = prompt.strip() or DEFAULT_PROMPT
    tensor = generate_diffusion_cond(
        model=model, steps=40, cfg_scale=5,
        conditioning=[{"prompt": prompt, "seconds_start": 0, "seconds_total": duration}],
        sample_size=config["sample_size"], sigma_min=0.3, sigma_max=500,
        sampler_type="euler", device=("mps" if torch.backends.mps.is_available() else "cpu")
    )
    # Collapse batch into the time axis: (b, d, n) -> (d, b*n).
    tensor = rearrange(tensor, "b d n -> d (b n)").float()
    # Peak-normalize; guard the degenerate all-silent clip, which previously
    # divided by zero and produced NaNs.
    peak = tensor.abs().max()
    if peak > 0:
        tensor = tensor / peak
    tensor = (tensor * 32767).to(torch.int16).cpu()
    if state["record"]:
        # Nanosecond timestamp gives a unique, monotonic name; the old
        # torch.rand()-derived name could collide and consumed RNG state.
        fn = os.path.join(SAVE_DIR, f"clip_{time.time_ns()}.wav")
        torchaudio.save(fn, tensor, config["sample_rate"])
        return AudioSegment.from_file(fn)
    buf = io.BytesIO()
    torchaudio.save(buf, tensor, config["sample_rate"], format="wav")
    buf.seek(0)
    return AudioSegment.from_file(buf, format="wav")
|
|
|
|
def segment_to_np(seg):
    """Convert a pydub AudioSegment into playback-ready numpy samples.

    Returns (samples, frame_rate) where samples is a float32 array of shape
    (n_frames, channels) scaled into [-1.0, 1.0) from 16-bit integers.
    """
    raw = seg.get_array_of_samples()
    samples = np.asarray(raw, dtype=np.float32).reshape(-1, seg.channels)
    return samples / 32768.0, seg.frame_rate
|
|
|
|
# — Playback thread —
|
|
def playback_worker():
    """Daemon loop: pull decoded clips and stream them to the audio device.

    Writes float32 frames to a fixed 44.1 kHz stereo output stream in
    blocksize chunks, honouring pause (state["playing"]) and volume.

    NOTE(review): the per-clip sample rate `sr` is carried along only for
    resume bookkeeping and never applied to the stream — a clip that is not
    44100 Hz would play at the wrong pitch. Confirm the model always emits
    44.1 kHz audio.
    """
    blocksize = 1024
    stream = sd.OutputStream(samplerate=44100, channels=2, dtype='float32', blocksize=blocksize)
    stream.start()
    while True:
        # Idle until the user presses play AND the model has loaded.
        while not state["playing"] or not state["model_ready"]:
            time.sleep(0.05)
        if state["resume_np"] is not None:
            # Pick up a clip that was paused part-way through last time.
            arr, sr, idx = state["resume_np"], state["resume_sr"], state["resume_idx"]
            state["resume_np"] = state["resume_sr"] = state["resume_idx"] = None
        else:
            # Block until the generator thread hands over a fresh clip.
            arr, sr = AUDIO_QUEUE.get()
            idx = 0
        state["spinner"] = False
        state["need_ui_update"] = True
        while idx < arr.shape[0] and state["playing"]:
            # Volume 80 maps to unity gain; slider values above 80 amplify
            # (and may clip). NOTE(review): confirm /80 rather than /100 is
            # intentional headroom/boost.
            v = state["volume"] / 80
            chunk = arr[idx:idx+blocksize] * v
            if chunk.shape[1] != 2:
                # Duplicate a mono column so it matches the stereo stream.
                chunk = np.tile(chunk, (1,2))
            stream.write(chunk)
            idx += blocksize
            # Record progress every block so pausing can resume seamlessly.
            state["resume_np"], state["resume_sr"], state["resume_idx"] = arr, sr, idx
        if idx >= arr.shape[0]:
            # Clip finished: clear the resume slot so the next pass dequeues.
            state["resume_np"] = state["resume_sr"] = state["resume_idx"] = None
|
|
|
|
# — Generator thread —
|
|
def generator_worker():
    """Daemon loop: generate clips and enqueue crossfaded chunks for playback.

    Each pass queues the current clip's body, then a FADE_MS crossfade into
    the next clip (pydub slice units are milliseconds), so playback is
    gapless. The following clip is prefetched on a side thread while the
    current clip's middle section is queued.
    """
    seg1 = seg2 = None
    while True:
        if not state["playing"] or not state["model_ready"]:
            time.sleep(0.1); continue
        if AUDIO_QUEUE.full():
            # Back off until the playback thread drains the queue a little.
            time.sleep(0.05); continue
        prompt = state["prompt"].strip() or DEFAULT_PROMPT
        if seg1 is None:
            # Cold start: synthesize the very first clip synchronously.
            seg1 = generate_and_save(prompt)
        # NOTE(review): this regenerates seg2 on every pass, discarding the
        # clip prefetched into next_seg below — it looks like it was meant to
        # be guarded with `if seg2 is None:`. Confirm before changing.
        seg2 = generate_and_save(prompt)
        # Queue seg1's body (everything before the crossfade window).
        cut = len(seg1) - FADE_MS
        if cut > 0: AUDIO_QUEUE.put(segment_to_np(seg1[:cut]))
        # Crossfade seg1's tail into seg2's head; if seg1 is shorter than
        # FADE_MS, fall back to seg2's head with no fade.
        cross = seg1[cut:].append(seg2[:FADE_MS], crossfade=FADE_MS) if cut > 0 else seg2[:FADE_MS]
        AUDIO_QUEUE.put(segment_to_np(cross))
        # Prefetch the following clip in the background, re-reading the
        # prompt so UI edits take effect for the next clip.
        next_seg = [None]
        t = threading.Thread(target=lambda: next_seg.__setitem__(0, generate_and_save(state["prompt"].strip() or DEFAULT_PROMPT)))
        t.start()
        # Queue seg2's middle (its head already went into the crossfade).
        cut2 = len(seg2) - FADE_MS
        if cut2 > 0: AUDIO_QUEUE.put(segment_to_np(seg2[FADE_MS:cut2]))
        t.join()
        seg1, seg2 = seg2, next_seg[0]
        # NOTE(review): dead code — the loop continues here regardless.
        if seg1 is None or seg2 is None:
            continue
|
|
|
|
# — GUI sync thread —
|
|
def gui_callback_loop(win):
    """Daemon loop: mirror the shared state into the web page.

    Polls every 100 ms; whenever any watched field changes (or a refresh is
    explicitly requested via need_ui_update) it calls the page's updateUI()
    with JS-compatible literals.
    """
    prev = (None, None, None, None, None, None)
    while True:
        snapshot = (
            state["spinner"], state["playing"], state["prompt"],
            state["volume"], state["model_ready"], state["record"],
        )
        if snapshot != prev or state["need_ui_update"]:
            # Booleans become JS "true"/"false"; the prompt is JSON-escaped
            # so quotes/newlines cannot break the script string.
            args = ",".join((
                str(state["playing"]).lower(),
                str(state["spinner"]).lower(),
                json.dumps(state["prompt"]),
                str(state["volume"]),
                str(state["model_ready"]).lower(),
                str(state["record"]).lower(),
            ))
            win.evaluate_js("updateUI(" + args + ")")
            state["need_ui_update"] = False
            prev = snapshot
        time.sleep(0.1)
|
|
|
|
# — JS API —
|
|
class API:
    """Methods exposed to the embedded web page through pywebview's js_api."""

    def playpause(self):
        """Toggle playback. Returns the new playing state (False if the model
        is not loaded yet, in which case nothing changes)."""
        if not state["model_ready"]:
            return False
        now_playing = not state["playing"]
        state["playing"] = now_playing
        # Spinner only when starting with nothing buffered or resumable.
        state["spinner"] = bool(
            now_playing and state["resume_np"] is None and AUDIO_QUEUE.empty()
        )
        state["need_ui_update"] = True
        if not now_playing:
            sd.stop()
        return now_playing

    def is_playing(self):
        return state["playing"]

    def spinner(self):
        return state["spinner"]

    def set_prompt(self, v):
        """Store the prompt typed in the UI and request a refresh."""
        state["prompt"] = v
        state["need_ui_update"] = True

    def get_prompt(self):
        return state["prompt"]

    def set_volume(self, v):
        """Store the 0-100 slider value (coerced to int)."""
        state["volume"] = int(v)
        state["need_ui_update"] = True

    def get_volume(self):
        return state["volume"]

    def model_ready(self):
        return state["model_ready"]

    def get_record(self):
        return state["record"]

    def set_record(self, value):
        """Toggle whether generated clips are also saved to disk."""
        state["record"] = bool(value)
        state["need_ui_update"] = True
|
|
|
|
# — HTML front-end —
|
|
HTML = """<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<style>
|
|
html, body {
|
|
height: 100%; width: 100%;
|
|
margin: 0; padding: 0;
|
|
background: #1a1208;
|
|
color: #e8dbc2;
|
|
font-family: sans-serif;
|
|
overflow: hidden;
|
|
}
|
|
#container {
|
|
width: 100%; height: 100%;
|
|
box-sizing: border-box;
|
|
display: flex;
|
|
flex-direction: column;
|
|
height: 100vh;
|
|
}
|
|
#prompt {
|
|
width: calc(100% - 24px);
|
|
height: calc(100% - 64px);
|
|
min-height: 44px;
|
|
font-size: 13px;
|
|
margin: 12px 12px 0 12px;
|
|
background: #e8dbc2;
|
|
color: #2c2113;
|
|
border: none;
|
|
border-radius: 4px;
|
|
resize: none;
|
|
padding: 5px;
|
|
box-sizing: border-box;
|
|
}
|
|
#prompt::placeholder { color: #a08663; }
|
|
.controls {
|
|
width: 100%;
|
|
box-sizing: border-box;
|
|
position: absolute;
|
|
left: 0; bottom: 8px;
|
|
display: flex;
|
|
justify-content: center;
|
|
align-items: flex-end;
|
|
pointer-events: auto;
|
|
}
|
|
.inner-controls {
|
|
display: grid;
|
|
grid-template-columns: 38px 46px 68px;
|
|
gap: 10px;
|
|
align-items: center;
|
|
justify-items: center;
|
|
background: none;
|
|
padding: 0 0 0 0;
|
|
}
|
|
.iconbtn {
|
|
display: flex;
|
|
align-items: center;
|
|
justify-content: center;
|
|
width: 38px; height: 38px;
|
|
background: none;
|
|
border: none;
|
|
outline: none;
|
|
cursor: pointer;
|
|
padding: 0;
|
|
margin: 0;
|
|
}
|
|
#record-dot {
|
|
width: 18px; height: 18px;
|
|
border-radius: 50%;
|
|
background: #c94040;
|
|
border: 2px solid #ac2323;
|
|
box-shadow: 0 0 2px #ac2323;
|
|
transition: background 0.2s, border-color 0.2s;
|
|
}
|
|
#record-dot.off {
|
|
background: #2c2113;
|
|
border-color: #665050;
|
|
box-shadow: none;
|
|
}
|
|
#record:focus { outline: none; }
|
|
.playcenter {
|
|
display: flex;
|
|
align-items: center;
|
|
justify-content: center;
|
|
gap: 6px;
|
|
}
|
|
#spinbox {
|
|
display: flex;
|
|
align-items: center;
|
|
justify-content: center;
|
|
margin-left: 0;
|
|
}
|
|
.spinner {
|
|
width: 22px; height: 22px;
|
|
border: 4px solid #ab865b;
|
|
border-top: 4px solid #e8dbc2;
|
|
border-radius: 50%;
|
|
animation: spin .7s linear infinite;
|
|
}
|
|
@keyframes spin { 100% { transform: rotate(360deg); } }
|
|
#volume {
|
|
width: 60px;
|
|
margin-left: 6px;
|
|
vertical-align: middle;
|
|
accent-color: #e8dbc2; /* For Chrome, Edge, Firefox */
|
|
}
|
|
input[type="range"]::-webkit-slider-thumb {
|
|
margin-top: -6px;
|
|
background: #e8dbc2;
|
|
}
|
|
input[type="range"]::-webkit-slider-runnable-track {
|
|
background: #a08663;
|
|
height: 4px;
|
|
border-radius: 2px;
|
|
}
|
|
input[type="range"]::-moz-range-thumb {
|
|
background: #e8dbc2;
|
|
}
|
|
input[type="range"]::-moz-range-track {
|
|
background: #a08663;
|
|
height: 4px;
|
|
border-radius: 2px;
|
|
}
|
|
input[type="range"]::-ms-thumb {
|
|
background: #e8dbc2;
|
|
}
|
|
input[type="range"]::-ms-fill-lower, input[type="range"]::-ms-fill-upper {
|
|
background: #a08663;
|
|
border-radius: 2px;
|
|
}
|
|
input[type="range"]:focus {
|
|
outline: none;
|
|
}
|
|
#model_loading_overlay{
|
|
position:fixed;left:0;top:0;width:100vw;height:100vh;z-index:100;
|
|
background:#1a1208ee;display:flex;flex-direction:column;align-items:center;justify-content:center;
|
|
}
|
|
</style>
|
|
</head>
|
|
<body>
|
|
<div id="container">
|
|
<textarea id="prompt" rows="2" placeholder="Describe your tavern music..."></textarea>
|
|
<div class="controls" id="controls">
|
|
<div class="inner-controls">
|
|
<button id="record" class="iconbtn" title="Toggle saving clips">
|
|
<div id="record-dot"></div>
|
|
</button>
|
|
<div class="playcenter">
|
|
<button id="playpause" class="iconbtn" aria-label="Play/Pause">
|
|
<svg id="playpause-icon" width="26" height="26" viewBox="0 0 26 26"></svg>
|
|
</button>
|
|
<div id="spinbox" style="display:none"><div class="spinner"></div></div>
|
|
</div>
|
|
<input id="volume" type="range" min="0" max="100" value="80" title="Volume">
|
|
</div>
|
|
</div>
|
|
</div>
|
|
<div id="model_loading_overlay">
|
|
<div class="spinner"></div>
|
|
<div style="margin-top:16px;color:#e8dbc2;font-size:16px;">Loading AI Model…</div>
|
|
</div>
|
|
<script>
|
|
let updating = false;
|
|
const promptBox = document.getElementById('prompt');
|
|
let recordBtn = document.getElementById('record');
|
|
let recordDot = document.getElementById('record-dot');
|
|
let playpauseBtn = document.getElementById('playpause');
|
|
let playpauseIcon = document.getElementById('playpause-icon');
|
|
let spinbox = document.getElementById('spinbox');
|
|
let volumeSlider = document.getElementById('volume');
|
|
|
|
function setPlayIcon(isPlaying) {
|
|
playpauseIcon.innerHTML = "";
|
|
if(isPlaying){
|
|
playpauseIcon.innerHTML = `<rect x="4" y="4" width="6" height="18" rx="2" fill="#fff"/><rect x="16" y="4" width="6" height="18" rx="2" fill="#fff"/>`;
|
|
}else{
|
|
playpauseIcon.innerHTML = `<polygon points="6,4 22,13 6,22" fill="#fff"/>`;
|
|
}
|
|
}
|
|
function updateRecordUI(isOn) {
|
|
if(isOn){
|
|
recordDot.classList.remove('off');
|
|
} else {
|
|
recordDot.classList.add('off');
|
|
}
|
|
}
|
|
recordBtn.onclick = async function() {
|
|
let newVal = !(await pywebview.api.get_record());
|
|
await pywebview.api.set_record(newVal);
|
|
updateRecordUI(newVal);
|
|
};
|
|
promptBox.oninput = function() {
|
|
if (updating) return;
|
|
pywebview.api.set_prompt(this.value);
|
|
};
|
|
playpauseBtn.onclick = ()=>pywebview.api.playpause();
|
|
volumeSlider.oninput = ()=>!updating&&pywebview.api.set_volume(volumeSlider.value);
|
|
|
|
async function updateUI(p, s, prompt, vol, modelReady, recording){
|
|
updating=true;
|
|
playpauseBtn.style.display=s?'none':'';
|
|
spinbox.style.display=s?'':'none';
|
|
setPlayIcon(p);
|
|
if(document.activeElement !== promptBox && promptBox.value.length === 0)
|
|
promptBox.value=prompt;
|
|
volumeSlider.value=vol;
|
|
document.getElementById('model_loading_overlay').style.display = modelReady ? 'none' : 'flex';
|
|
updateRecordUI(recording);
|
|
updating=false;
|
|
}
|
|
window.onload=async()=>{
|
|
promptBox.value=await pywebview.api.get_prompt();
|
|
volumeSlider.value=await pywebview.api.get_volume();
|
|
pywebview.api.model_ready().then(ready=>{
|
|
document.getElementById('model_loading_overlay').style.display=ready?'none':'flex';
|
|
});
|
|
let rec = await pywebview.api.get_record();
|
|
updateRecordUI(rec);
|
|
setPlayIcon(false);
|
|
};
|
|
</script>
|
|
</body>
|
|
</html>"""
|
|
|
|
# — Launch GUI & threads —
|
|
def gui():
    """Create the pywebview window, start the worker threads, then block in
    the GUI event loop until the window is closed."""
    api = API()
    win = webview.create_window(
        "Tavern Generator",
        html=HTML,
        width=350,
        height=200,
        js_api=api,
        resizable=False,
    )
    threading.Thread(target=gui_callback_loop, args=(win,), daemon=True).start()
    for worker in (playback_worker, generator_worker):
        threading.Thread(target=worker, daemon=True).start()
    webview.start(debug=False)
|
|
|
|
# Entry point: load the model on a background thread so the window appears
# immediately, while the (blocking) GUI loop runs on the main thread.
if __name__=="__main__":
    threading.Thread(target=load_model, daemon=True).start()
    gui()