Initial commit
ddpo_pytorch/stat_tracking.py (new file, 34 lines added)
@@ -0,0 +1,34 @@
import numpy as np
from collections import deque


class PerPromptStatTracker:
    """Keeps a rolling buffer of rewards for each prompt and normalizes
    incoming rewards into advantages using per-prompt statistics."""

    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size  # max rewards retained per prompt
        self.min_count = min_count  # below this, fall back to batch-wide stats
        self.stats = {}

    def update(self, prompts, rewards):
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            if len(self.stats[prompt]) < self.min_count:
                # Too little history for this prompt: normalize against the whole batch.
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                # Normalize against this prompt's reward buffer.
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        return {
            k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
            for k, v in self.stats.items()
        }
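For context, a minimal usage sketch (not part of the committed file): it assumes `prompts` is a NumPy array of strings and `rewards` is a float array of the same length, and the parameter values are illustrative only.

import numpy as np

# Hypothetical usage sketch; buffer_size/min_count values are assumptions.
tracker = PerPromptStatTracker(buffer_size=16, min_count=4)

prompts = np.array(["a cat", "a cat", "a dog", "a dog"])
rewards = np.array([0.9, 0.7, 0.2, 0.4], dtype=np.float32)

# Each call folds the new rewards into the per-prompt buffers and returns
# advantages normalized by per-prompt (or, while buffers are small, batch-wide) statistics.
advantages = tracker.update(prompts, rewards)
print(advantages)
print(tracker.get_stats())  # per-prompt mean, std, and buffer count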