[Bugfix] Fix ColumnParallelLinearWithLoRA slice (#11708)

Signed-off-by: ZincCat <zincchloride@outlook.com>
This commit is contained in:
ZincCat
2025-01-03 13:02:34 -08:00
committed by GitHub
parent 80c751e7f6
commit 61fed92c7e

View File

@ -479,7 +479,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# ColumnParallelLinear.
else:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
@ -490,7 +490,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
if bias is None:
return bias
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
bias = bias[start_idx:end_idx]