Files
infinite-sound/infinite-sound.py
2025-07-18 01:12:16 +02:00

408 lines
13 KiB
Python

import threading, queue, time, os, json
import torch, torchaudio
import numpy as np
import sounddevice as sd
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from pydub import AudioSegment
import webview
# Directory that receives generated clips while recording is enabled; it is
# created at import time so saving a clip later does not have to create it.
SAVE_DIR = "clips"
os.makedirs(SAVE_DIR, exist_ok=True)

# Overlap, in milliseconds, used when crossfading one clip into the next.
FADE_MS = 3000

# Prompt used whenever the user's prompt is empty or only whitespace.
DEFAULT_PROMPT = (
    "cozy medieval tavern ambience with flute, harp and soft percussion, "
    "hearthstone, folk, peaceful forest cabin melody with lute and gentle "
    "wind chimes, warm and safe, serene elven inn with light lyre and soft "
    "strings, magical and relaxing, evening at a rustic inn, lute and low "
    "flute with crackling fire village square at dusk, medieval folk "
    "instruments and background chatter, ocarina of time, folk, pagan, "
    "percussion, congas, drums"
)
# — Shared state & buffer —
# Plain dict shared by several threads (JS API calls, playback, generator,
# model loader, UI sync); no locking is used.
state = dict(
    playing=False,          # user wants audio playing (toggled by API.playpause)
    prompt=DEFAULT_PROMPT,  # text prompt for the next generated clip
    spinner=False,          # show the spinner instead of the play button
    volume=80,              # slider value 0-100; playback gain is volume / 80
    need_ui_update=False,   # forces gui_callback_loop to push state to the page
    resume_np=None,         # samples of the clip interrupted by a pause
    resume_sr=None,         # sample rate belonging to resume_np
    resume_idx=0,           # frame index to continue resume_np from
    model_ready=False,      # set by load_model once the model can be used
    record=True,            # save each generated clip under SAVE_DIR
)

# Hand-off from generator_worker to playback_worker of (samples, sample_rate)
# tuples; bounded, so put() blocks once three blocks are waiting.
AUDIO_QUEUE = queue.Queue(maxsize=3)

# Filled in by load_model() on its background thread.
model = None
config = None
# — Model loading in the background —
def load_model():
    """Load Stable Audio Open Small and publish it through module globals.

    Meant to run on a background thread (it is started from the ``__main__``
    guard). On completion it sets the ``model`` and ``config`` globals, flips
    ``state["model_ready"]`` and requests a UI refresh so the page can hide
    its loading overlay.

    The weights go to Apple's MPS backend when available, otherwise to the
    CPU. Half precision is applied only on the accelerator: float16 inference
    on CPU is typically much slower than float32, and depending on the
    PyTorch version not every CPU kernel supports it.
    """
    global model, config
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    loaded_model, loaded_config = get_pretrained_model("stabilityai/stable-audio-open-small")
    loaded_model = loaded_model.to(device)
    if device != "cpu":
        loaded_model = loaded_model.half()
    model = loaded_model
    config = loaded_config
    # Only announce readiness after both globals are assigned, because the
    # worker threads start using them as soon as model_ready is True.
    state["model_ready"] = True
    state["need_ui_update"] = True
# — Audio generation helper —
def generate_and_save(prompt, duration=11):
    """Generate one clip for ``prompt`` and return it as a pydub AudioSegment.

    Blocks until ``load_model`` has finished. A blank prompt falls back to
    DEFAULT_PROMPT. The rendered audio is trimmed to ``duration`` seconds,
    peak-normalised and converted to 16-bit PCM. When ``state["record"]`` is
    true the clip is also written as a WAV file under SAVE_DIR; otherwise it
    is round-tripped through an in-memory WAV buffer.

    Args:
        prompt: text prompt for the model.
        duration: clip length in seconds, sent as ``seconds_total``
            conditioning and used to trim the rendered audio.

    Returns:
        pydub.AudioSegment containing the clip.
    """
    while not state["model_ready"]:
        time.sleep(0.1)
    prompt = prompt.strip() or DEFAULT_PROMPT
    sample_rate = config["sample_rate"]
    # Use the device load_model actually put the weights on, so the device
    # choice is made in exactly one place.
    device = next(model.parameters()).device.type
    tensor = generate_diffusion_cond(
        model=model, steps=40, cfg_scale=5,
        conditioning=[{"prompt": prompt, "seconds_start": 0, "seconds_total": duration}],
        sample_size=config["sample_size"], sigma_min=0.3, sigma_max=500,
        sampler_type="euler", device=device,
    )
    tensor = rearrange(tensor, "b d n -> d (b n)").float()
    # The model always renders sample_size frames. Anything after
    # seconds_total is padding that would otherwise leak into the crossfade
    # between clips, so keep only the requested duration.
    tensor = tensor[:, :int(duration * sample_rate)]
    # Guard against an all-zero render: dividing by a zero peak gives NaNs,
    # which turn into garbage when cast to int16.
    peak = tensor.abs().max().clamp(min=1e-8)
    tensor = (tensor / peak).clamp(-1, 1)
    tensor = (tensor * 32767).to(torch.int16).cpu()
    if state["record"]:
        fn = os.path.join(SAVE_DIR, f"clip_{int(torch.rand(1).item()*1e9)}.wav")
        torchaudio.save(fn, tensor, sample_rate)
        return AudioSegment.from_file(fn)
    import io
    buf = io.BytesIO()
    torchaudio.save(buf, tensor, sample_rate, format="wav")
    buf.seek(0)
    return AudioSegment.from_file(buf, format="wav")
def segment_to_np(seg):
    """Convert a pydub AudioSegment into a float32 sample array.

    Args:
        seg: segment providing ``get_array_of_samples()``, ``channels``,
            ``sample_width`` (bytes per sample) and ``frame_rate``.

    Returns:
        ``(samples, frame_rate)`` where ``samples`` has shape
        ``(frames, channels)`` and values scaled to ``[-1.0, 1.0)``.
    """
    samples = np.array(seg.get_array_of_samples(), dtype=np.float32)
    # Full-scale magnitude of signed PCM at this width (32768 for 16-bit),
    # instead of assuming every segment is 16-bit.
    full_scale = float(1 << (8 * seg.sample_width - 1))
    samples = samples.reshape(-1, seg.channels) / full_scale
    return samples, seg.frame_rate
# — Playback thread —
def _as_stereo(frames):
    """Map a ``(frames, channels)`` block onto two output channels.

    Mono is duplicated into both channels, stereo passes through unchanged,
    and channels beyond the first two are dropped.
    """
    channels = frames.shape[1]
    if channels == 2:
        return frames
    if channels == 1:
        return np.repeat(frames, 2, axis=1)
    return frames[:, :2]


def playback_worker():
    """Play queued audio forever on a stereo float32 output stream.

    Waits while playback is paused or the model is still loading. Each clip
    comes either from the pause/resume slot in ``state`` or from
    ``AUDIO_QUEUE`` as a ``(samples, sample_rate)`` tuple, and is written in
    ``blocksize``-frame chunks scaled by ``state["volume"] / 80`` (so the
    default slider value of 80 is unity gain). When paused mid-clip, the
    clip and the position reached are stored in ``state["resume_*"]``.

    The output stream is opened at the clip's own sample rate and reopened
    if a later clip uses a different rate, rather than assuming 44.1 kHz.
    """
    blocksize = 1024
    stream = None
    stream_rate = None
    while True:
        while not state["playing"] or not state["model_ready"]:
            time.sleep(0.05)
        if state["resume_np"] is not None:
            arr, sr, idx = state["resume_np"], state["resume_sr"], state["resume_idx"]
            state["resume_np"] = state["resume_sr"] = state["resume_idx"] = None
        else:
            arr, sr = AUDIO_QUEUE.get()
            idx = 0
        state["spinner"] = False
        state["need_ui_update"] = True

        if stream is None or sr != stream_rate:
            if stream is not None:
                stream.stop()
                stream.close()
            stream = sd.OutputStream(samplerate=sr, channels=2, dtype="float32", blocksize=blocksize)
            stream.start()
            stream_rate = sr

        while idx < arr.shape[0] and state["playing"]:
            gain = state["volume"] / 80
            # The multiplication yields a fresh C-contiguous array even when
            # _as_stereo returned a column-sliced view.
            stream.write(_as_stereo(arr[idx:idx + blocksize]) * gain)
            idx += blocksize

        if idx >= arr.shape[0]:
            state["resume_np"] = state["resume_sr"] = state["resume_idx"] = None
        else:
            # Paused mid-clip: remember where to pick up again.
            state["resume_np"], state["resume_sr"], state["resume_idx"] = arr, sr, idx
# — Generator thread —
def _enqueue_segment(seg):
    """Queue ``seg`` for playback unless it is empty.

    ``AUDIO_QUEUE.put`` blocks while the queue is full, which paces the
    generator to the playback thread.
    """
    if len(seg) > 0:
        AUDIO_QUEUE.put(segment_to_np(seg))


def _current_prompt():
    """Return the user's prompt, or DEFAULT_PROMPT when it is blank."""
    return state["prompt"].strip() or DEFAULT_PROMPT


def _generate_into(holder):
    """Generate one clip for the current prompt and store it in ``holder[0]``.

    If generation raises, the helper thread dies and ``holder[0]`` keeps its
    previous value (None).
    """
    holder[0] = generate_and_save(_current_prompt())


def generator_worker():
    """Generate clips forever and queue them for crossfaded playback.

    Each clip is split into head (first FADE_MS), middle, and tail (last
    FADE_MS). The first clip's head+middle is queued directly. After that,
    every clip contributes exactly once: its head is crossfaded with the
    previous clip's held-back tail, its middle is queued, and its own tail
    is held back for the next crossfade.

    While a clip's middle is being queued (``put`` may block until playback
    catches up), the following clip is generated on a helper thread.
    """
    prev = None  # clip whose audio, except its last FADE_MS, is already queued
    nxt = None   # generated clip none of whose audio has been queued yet
    while True:
        if not state["playing"] or not state["model_ready"]:
            time.sleep(0.1)
            continue
        if AUDIO_QUEUE.full():
            time.sleep(0.05)
            continue

        prompt = _current_prompt()
        if prev is None:
            prev = generate_and_save(prompt)
            _enqueue_segment(prev[:max(len(prev) - FADE_MS, 0)])
        if nxt is None:
            # First pass, or the previous background generation failed.
            nxt = generate_and_save(prompt)

        # Crossfade prev's held-back tail into nxt's head; bounding the
        # crossfade by both lengths keeps pydub from rejecting it.
        tail = prev[max(len(prev) - FADE_MS, 0):]
        head = nxt[:FADE_MS]
        _enqueue_segment(tail.append(head, crossfade=min(len(tail), len(head))))

        # Render the clip after nxt while nxt's middle drains into the queue.
        upcoming = [None]
        worker = threading.Thread(target=_generate_into, args=(upcoming,), daemon=True)
        worker.start()
        _enqueue_segment(nxt[FADE_MS:max(len(nxt) - FADE_MS, FADE_MS)])
        worker.join()

        # nxt's tail is now the held-back part. If the helper failed,
        # upcoming[0] is None and the next pass regenerates it synchronously.
        prev, nxt = nxt, upcoming[0]
# — GUI sync thread —
def gui_callback_loop(win):
    """Push backend state to the page roughly every 100 ms when it changes.

    Calls the page's ``updateUI(playing, spinner, prompt, volume,
    modelReady, recording)`` through ``win.evaluate_js`` whenever the
    watched values differ from the last push, or when another thread set
    ``state["need_ui_update"]``.

    Args:
        win: pywebview window whose ``evaluate_js`` runs the update script.
    """
    last = None
    while True:
        # One snapshot drives both the change check and the pushed values,
        # so the page always receives exactly what was compared.
        snapshot = (
            state["playing"], state["spinner"], state["prompt"],
            state["volume"], state["model_ready"], state["record"],
        )
        if snapshot != last or state["need_ui_update"]:
            # Clear the request before pushing: a new request raised while
            # evaluate_js runs survives and triggers another push next tick.
            state["need_ui_update"] = False
            # json.dumps renders bools as true/false, ints verbatim and the
            # prompt as an escaped JS string literal.
            args = ",".join(json.dumps(value) for value in snapshot)
            win.evaluate_js(f"updateUI({args})")
            last = snapshot
        time.sleep(0.1)
# — JS API —
class API:
    """Backend methods the page calls through ``pywebview.api``.

    Every method reads or writes the shared ``state`` dict; the mutating
    methods also set ``need_ui_update`` so the change is pushed back to the
    page.
    """

    def playpause(self):
        """Toggle playback and return the resulting ``playing`` flag.

        Returns False without changing anything while the model is still
        loading. The spinner is shown only when starting playback with
        nothing buffered: no paused clip to resume and an empty AUDIO_QUEUE.
        """
        if not state["model_ready"]:
            return False
        starting = not state["playing"]
        state["playing"] = starting
        state["spinner"] = (
            starting and state["resume_np"] is None and AUDIO_QUEUE.empty()
        )
        state["need_ui_update"] = True
        if not starting:
            sd.stop()
        return state["playing"]

    def is_playing(self):
        """Return whether playback is active."""
        return state["playing"]

    def spinner(self):
        """Return whether the spinner should replace the play button."""
        return state["spinner"]

    def set_prompt(self, v):
        """Store the prompt used for subsequent generations."""
        state["prompt"] = v
        state["need_ui_update"] = True

    def get_prompt(self):
        """Return the current prompt text."""
        return state["prompt"]

    def set_volume(self, v):
        """Store the slider volume, converted with ``int()``."""
        state["volume"] = int(v)
        state["need_ui_update"] = True

    def get_volume(self):
        """Return the stored slider volume."""
        return state["volume"]

    def model_ready(self):
        """Return whether the model has finished loading."""
        return state["model_ready"]

    def get_record(self):
        """Return whether generated clips are being saved."""
        return state["record"]

    def set_record(self, value):
        """Enable or disable saving of generated clips."""
        state["record"] = bool(value)
        state["need_ui_update"] = True
# — HTML front-end —
# Front-end page served to pywebview. Its script initialises from the Python
# API only after the bridge signals 'pywebviewready'; it also exposes
# updateUI(), which gui_callback_loop calls via evaluate_js.
HTML = """<!DOCTYPE html>
<html>
<head>
<style>
html, body {
height: 100%; width: 100%;
margin: 0; padding: 0;
background: #1a1208;
color: #e8dbc2;
font-family: sans-serif;
overflow: hidden;
}
#container {
width: 100%; height: 100%;
box-sizing: border-box;
display: flex;
flex-direction: column;
height: 100vh;
}
#prompt {
width: calc(100% - 24px);
height: calc(100% - 64px);
min-height: 44px;
font-size: 13px;
margin: 12px 12px 0 12px;
background: #e8dbc2;
color: #2c2113;
border: none;
border-radius: 4px;
resize: none;
padding: 5px;
box-sizing: border-box;
}
#prompt::placeholder { color: #a08663; }
.controls {
width: 100%;
box-sizing: border-box;
position: absolute;
left: 0; bottom: 8px;
display: flex;
justify-content: center;
align-items: flex-end;
pointer-events: auto;
}
.inner-controls {
display: grid;
grid-template-columns: 38px 46px 68px;
gap: 10px;
align-items: center;
justify-items: center;
background: none;
padding: 0 0 0 0;
}
.iconbtn {
display: flex;
align-items: center;
justify-content: center;
width: 38px; height: 38px;
background: none;
border: none;
outline: none;
cursor: pointer;
padding: 0;
margin: 0;
}
#record-dot {
width: 18px; height: 18px;
border-radius: 50%;
background: #c94040;
border: 2px solid #ac2323;
box-shadow: 0 0 2px #ac2323;
transition: background 0.2s, border-color 0.2s;
}
#record-dot.off {
background: #2c2113;
border-color: #665050;
box-shadow: none;
}
#record:focus { outline: none; }
.playcenter {
display: flex;
align-items: center;
justify-content: center;
gap: 6px;
}
#spinbox {
display: flex;
align-items: center;
justify-content: center;
margin-left: 0;
}
.spinner {
width: 22px; height: 22px;
border: 4px solid #ab865b;
border-top: 4px solid #e8dbc2;
border-radius: 50%;
animation: spin .7s linear infinite;
}
@keyframes spin { 100% { transform: rotate(360deg); } }
#volume {
width: 60px;
margin-left: 6px;
vertical-align: middle;
accent-color: #e8dbc2; /* For Chrome, Edge, Firefox */
}
input[type="range"]::-webkit-slider-thumb {
margin-top: -6px;
background: #e8dbc2;
}
input[type="range"]::-webkit-slider-runnable-track {
background: #a08663;
height: 4px;
border-radius: 2px;
}
input[type="range"]::-moz-range-thumb {
background: #e8dbc2;
}
input[type="range"]::-moz-range-track {
background: #a08663;
height: 4px;
border-radius: 2px;
}
input[type="range"]::-ms-thumb {
background: #e8dbc2;
}
input[type="range"]::-ms-fill-lower, input[type="range"]::-ms-fill-upper {
background: #a08663;
border-radius: 2px;
}
input[type="range"]:focus {
outline: none;
}
#model_loading_overlay{
position:fixed;left:0;top:0;width:100vw;height:100vh;z-index:100;
background:#1a1208ee;display:flex;flex-direction:column;align-items:center;justify-content:center;
}
</style>
</head>
<body>
<div id="container">
<textarea id="prompt" rows="2" placeholder="Describe your tavern music..."></textarea>
<div class="controls" id="controls">
<div class="inner-controls">
<button id="record" class="iconbtn" title="Toggle saving clips">
<div id="record-dot"></div>
</button>
<div class="playcenter">
<button id="playpause" class="iconbtn" aria-label="Play/Pause">
<svg id="playpause-icon" width="26" height="26" viewBox="0 0 26 26"></svg>
</button>
<div id="spinbox" style="display:none"><div class="spinner"></div></div>
</div>
<input id="volume" type="range" min="0" max="100" value="80" title="Volume">
</div>
</div>
</div>
<div id="model_loading_overlay">
<div class="spinner"></div>
<div style="margin-top:16px;color:#e8dbc2;font-size:16px;">Loading AI Model…</div>
</div>
<script>
let updating = false;
const promptBox = document.getElementById('prompt');
let recordBtn = document.getElementById('record');
let recordDot = document.getElementById('record-dot');
let playpauseBtn = document.getElementById('playpause');
let playpauseIcon = document.getElementById('playpause-icon');
let spinbox = document.getElementById('spinbox');
let volumeSlider = document.getElementById('volume');
function setPlayIcon(isPlaying) {
playpauseIcon.innerHTML = "";
if(isPlaying){
playpauseIcon.innerHTML = `<rect x="4" y="4" width="6" height="18" rx="2" fill="#fff"/><rect x="16" y="4" width="6" height="18" rx="2" fill="#fff"/>`;
}else{
playpauseIcon.innerHTML = `<polygon points="6,4 22,13 6,22" fill="#fff"/>`;
}
}
function updateRecordUI(isOn) {
if(isOn){
recordDot.classList.remove('off');
} else {
recordDot.classList.add('off');
}
}
recordBtn.onclick = async function() {
let newVal = !(await pywebview.api.get_record());
await pywebview.api.set_record(newVal);
updateRecordUI(newVal);
};
promptBox.oninput = function() {
if (updating) return;
pywebview.api.set_prompt(this.value);
};
playpauseBtn.onclick = ()=>pywebview.api.playpause();
volumeSlider.oninput = ()=>!updating&&pywebview.api.set_volume(volumeSlider.value);
async function updateUI(p, s, prompt, vol, modelReady, recording){
updating=true;
playpauseBtn.style.display=s?'none':'';
spinbox.style.display=s?'':'none';
setPlayIcon(p);
if(document.activeElement !== promptBox && promptBox.value.length === 0)
promptBox.value=prompt;
volumeSlider.value=vol;
document.getElementById('model_loading_overlay').style.display = modelReady ? 'none' : 'flex';
updateRecordUI(recording);
updating=false;
}
// Draw the play icon immediately; it does not depend on the Python bridge.
setPlayIcon(false);
// pywebview.api is only guaranteed to be callable once the 'pywebviewready'
// event has fired; calling it from an onload handler can race the injection.
async function initFromBackend(){
promptBox.value=await pywebview.api.get_prompt();
volumeSlider.value=await pywebview.api.get_volume();
pywebview.api.model_ready().then(ready=>{
document.getElementById('model_loading_overlay').style.display=ready?'none':'flex';
});
let rec = await pywebview.api.get_record();
updateRecordUI(rec);
}
if (window.pywebview && window.pywebview.api && window.pywebview.api.get_prompt) {
initFromBackend();
} else {
window.addEventListener('pywebviewready', initFromBackend);
}
</script>
</body>
</html>"""
# — Launch GUI & threads —
def gui():
    """Create the app window, start the worker threads, then run webview.

    The UI-sync, playback and generator loops are started as daemon threads
    before control is handed to ``webview.start``.
    """
    window = webview.create_window(
        "Tavern Generator",
        html=HTML,
        width=350,
        height=200,
        js_api=API(),
        resizable=False,
    )
    workers = (
        (gui_callback_loop, (window,)),
        (playback_worker, ()),
        (generator_worker, ()),
    )
    for target, args in workers:
        threading.Thread(target=target, args=args, daemon=True).start()
    webview.start(debug=False)
if __name__ == "__main__":
    # Load the model off the main thread so the window, with its loading
    # overlay, can be shown while the weights are still loading.
    loader = threading.Thread(target=load_model, daemon=True)
    loader.start()
    gui()