Use runtime profiling to replace manual memory analyzers (#81)

This commit is contained in:
Zhuohan Li
2023-05-19 11:35:44 -06:00
committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@ -104,7 +104,8 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_dim,
self.scaling, rotary_dim=self.head_dim)
def forward(
self,