[Misc] Reduce supported Punica dtypes (#4304)
@@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
 
     def _pretest():
         linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
-                                1024, vocab_size)
+                                1024,
+                                vocab_size,
+                                params_dtype=torch.float16)
         linear.weight.data = torch.rand_like(linear.weight.data)
         linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
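
The new params_dtype=torch.float16 keyword makes the test build its base layer directly in half precision, in line with the commit title's reduction of the dtypes the Punica kernels support. A minimal standalone sketch of the same idea, using plain torch.nn.Linear instead of vLLM's ParallelLMHead (which needs a distributed context to construct):

import torch

# Create the layer's parameters directly in float16, as the diff does via
# params_dtype; torch.rand_like() inherits that dtype, so the randomized
# test weights stay half precision.
linear = torch.nn.Linear(1024, 32000, bias=False, dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
assert linear.weight.dtype == torch.float16

The 32000 output size here is illustrative only; the test's real dimensions come from vocab_size and lora_config.
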
@@ -445,7 +447,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
         num_inputs=8 * num_loras, # * 3,
         input_size=(1, 1024),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
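
These keyword arguments configure the test's random-input helper; the only change across all the call sites below is the requested dtype. A hypothetical stand-in for that helper (the name and exact signature are assumed here, not taken from the diff) shows what input_type=torch.float16 produces:

from typing import List, Tuple

import torch

def make_random_inputs(num_inputs: int,
                       input_size: Tuple[int, int],
                       input_range: Tuple[float, float],
                       input_type: torch.dtype) -> List[torch.Tensor]:
    # Uniform samples in [low, high), materialized at the requested dtype.
    low, high = input_range
    return [(torch.rand(input_size) * (high - low) + low).to(input_type)
            for _ in range(num_inputs)]

# After this change the call sites request float16 inputs in (0, 1):
inputs = make_random_inputs(8, (1, 1024), (0, 1), torch.float16)
assert all(x.dtype == torch.float16 for x in inputs)
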
@@ -494,7 +496,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
         num_inputs=8 * num_loras * 3,
         input_size=(1, 1024),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
 
     def create_random_linear_parallel_layer():
         if orientation == "row":
-            linear = RowParallelLinear(4096, 4096, bias=False)
+            linear = RowParallelLinear(4096,
+                                       4096,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = RowParallelLinearWithLoRA(linear)
         else:
-            linear = ColumnParallelLinear(4096, 4096, bias=False)
+            linear = ColumnParallelLinear(4096,
+                                          4096,
+                                          bias=False,
+                                          params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = ColumnParallelLinearWithLoRA(linear)
         lora_linear.create_lora_weights(max_loras, lora_config)
@@ -561,7 +569,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -600,7 +608,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
     def create_column_parallel_packed_layer():
         if repeats == 2:
             linear = MergedColumnParallelLinear(4096, [4096] * repeats,
-                                                bias=False)
+                                                bias=False,
+                                                params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedColumnParallelLinearWithLoRA(linear)
         elif repeats == 3:
-            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear = QKVParallelLinear(4096,
+                                       64,
+                                       32,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedQKVParallelLinearWithLora(linear)
         else:
-            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear = QKVParallelLinear(4096,
+                                       64,
+                                       32,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = QKVParallelLinearWithLora(linear)
 
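
Switching input_type and params_dtype together is what keeps these layer tests consistent: the base layer's forward pass is ultimately a matmul, and torch does not promote dtypes there. A small illustration in plain torch, independent of the vLLM layers above:

import torch

weight = torch.rand(4096, 4096, dtype=torch.float16)
x16 = torch.rand(1, 4096, dtype=torch.float16)
x32 = torch.rand(1, 4096, dtype=torch.float32)

y = x16 @ weight.t()  # dtypes match: fine, result is float16
try:
    _ = x32 @ weight.t()  # mismatched dtypes: matmul raises RuntimeError
except RuntimeError as err:
    print(f"dtype mismatch rejected: {err}")
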
@@ -676,7 +693,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -716,7 +733,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 