[Bugfix] Fix ColumnParallelLinearWithLoRA slice (#11708)
Signed-off-by: ZincCat <zincchloride@outlook.com>
This commit is contained in:
@ -479,7 +479,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
# ColumnParallelLinear.
|
||||
else:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_dim
|
||||
shard_size = self.output_size
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_b = lora_b[:, start_idx:end_idx]
|
||||
@ -490,7 +490,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
if bias is None:
|
||||
return bias
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_dim
|
||||
shard_size = self.output_size
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
bias = bias[start_idx:end_idx]
|
||||
|
||||
Reference in New Issue
Block a user