A startup’s GenAI application cost $0.42 per query at 15-second latency. At this rate, their Series A funding would last six months. The problem wasn’t the model—it was unoptimized inference. Each request loaded the model from scratch, processed tokens sequentially, and cached nothing. This is the default architecture most teams start with, and it quietly bankrupts them.
Optimizing LLM inference requires understanding where time and money actually go.
The Performance Challenge
LLMs present unique challenges:
Massive Compute: A forward pass through a 70B parameter model requires ~140 GFLOPs per generated token (roughly two operations per parameter). Generating 100 tokens therefore takes on the order of tens of teraFLOPs.
Memory Bandwidth Bottleneck: LLMs are memory-bound, not compute-bound. Loading 70B parameters (280GB in FP32) dominates inference time.
Sequential Dependencies: Each token depends on all previous tokens, limiting parallelization.
Variable Costs: A 10-token response costs 10x less than a 100-token response, but naive pricing ignores this.
(Interactive diagram omitted — the original page renders it with JavaScript.)
Profiling
Before optimizing, instrument everything:
class GenAIPerformanceProfiler:
    """Profile latency, memory, and cost of a single LLM inference request.

    Collects per-section wall-clock timings, CUDA kernel time, peak GPU
    memory, and a cost estimate into one dict per request, while also
    accumulating per-section timing history in ``self.metrics``.

    Bug fix vs. the original draft: ``profile_data['timings']`` was read
    (for tokens/sec) but nothing ever wrote to it — the ``timer`` helper
    below now records each section's elapsed time into the active request.
    """

    def __init__(self):
        # Running history of section timings across all profiled requests.
        self.metrics = defaultdict(list)
        # CostCalculator is defined elsewhere in the project — TODO confirm API.
        self.cost_calculator = CostCalculator()
        # Timings dict of the request currently being profiled (None when idle).
        self._active_timings = None

    def timer(self, name):
        """Return a context manager that times the ``name`` section.

        On exit it appends the elapsed seconds to ``self.metrics[name]`` and,
        when a request is being profiled, stores it in that request's
        ``timings`` dict — this is what makes ``timings['forward_pass']``
        available for the tokens/sec computation in ``profile_inference``.
        """
        import time  # local import keeps this block self-contained

        profiler = self

        class _Section:
            def __enter__(self):
                self._t0 = time.perf_counter()
                return self

            def __exit__(self, exc_type, exc, tb):
                elapsed = time.perf_counter() - self._t0
                profiler.metrics[name].append(elapsed)
                if profiler._active_timings is not None:
                    profiler._active_timings[name] = elapsed
                return False  # never suppress exceptions

        return _Section()

    def profile_inference(self, model, inputs):
        """Profile the complete inference pipeline for one request.

        Args:
            model: HF-style causal LM exposing ``generate`` — TODO confirm.
            inputs: tokenized batch containing an ``input_ids`` tensor; only
                the first sequence's lengths are measured.

        Returns:
            dict with request id, per-section timings, CUDA time, peak memory,
            throughput, cost estimate, and detected bottlenecks.
        """
        profile_data = {
            'request_id': str(uuid.uuid4()),
            'timestamp': datetime.now(),
            'input_tokens': len(inputs['input_ids'][0]),
            'timings': {},
            'memory': {},
            'costs': {}
        }
        # Route section timings into this request's dict for the duration
        # of the call (restored in the finally below).
        self._active_timings = profile_data['timings']
        try:
            torch.cuda.reset_peak_memory_stats()
            start_memory = torch.cuda.memory_allocated()
            with self.timer('model_loading'):
                # is_model_cached / load_model are provided elsewhere in the
                # project — TODO confirm their contracts.
                if not self.is_model_cached(model):
                    model = self.load_model(model)
            with self.timer('preprocessing'):
                processed_inputs = self.preprocess(inputs)
            with self.timer('attention_computation'):
                with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True
                ) as prof:
                    with self.timer('forward_pass'):
                        with torch.no_grad():
                            outputs = model.generate(
                                **processed_inputs,
                                max_new_tokens=512,
                                do_sample=True,
                                temperature=0.7,
                                return_dict_in_generate=True,
                                output_attentions=True,
                                output_hidden_states=True
                            )
            # Total CUDA kernel time across all profiled ops, in ms.
            profile_data['cuda_time_ms'] = sum(
                item.cuda_time_total for item in prof.key_averages()
            ) / 1000
            profile_data['memory']['peak_gb'] = (
                torch.cuda.max_memory_allocated() - start_memory
            ) / (1024 ** 3)
            profile_data['output_tokens'] = (
                len(outputs.sequences[0]) - len(inputs['input_ids'][0])
            )
            profile_data['tokens_per_second'] = (
                profile_data['output_tokens'] /
                profile_data['timings']['forward_pass']
            )
            profile_data['costs'] = self.cost_calculator.calculate(
                compute_time=profile_data['cuda_time_ms'] / 1000,
                memory_gb=profile_data['memory']['peak_gb'],
                input_tokens=profile_data['input_tokens'],
                output_tokens=profile_data['output_tokens']
            )
            profile_data['bottlenecks'] = self.identify_bottlenecks(prof)
        finally:
            self._active_timings = None
        return profile_data
Baseline metrics:
- Average latency: 12.3 seconds per request
- Throughput: 0.08 requests per second per GPU
- Memory usage: 95% of GPU memory for single request
- Cost per 1k tokens: $0.0021
Dynamic Batching
Continuous Batching Implementation
class ContinuousBatchingEngine:
    """Continuous-batching inference loop.

    Requests are pulled from an asyncio queue, grouped into compatible
    batches, and decoded token-by-token; finished sequences drop out of the
    batch while the remaining ones keep generating.
    """

    def __init__(self, model, max_batch_size=32, max_wait_time_ms=50):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time_ms = max_wait_time_ms
        self.request_queue = asyncio.Queue()
        self.active_sequences = {}

    async def process_requests(self):
        """Main processing loop with continuous batching."""
        while True:
            current = await self.collect_batch()
            if current:
                await self.process_batch(current)
            self.handle_completions()

    async def collect_batch(self):
        """Collect queued requests into a batch, bounded by size and deadline."""
        collected = []
        cutoff = time.time() + self.max_wait_time_ms / 1000
        while len(collected) < self.max_batch_size and time.time() < cutoff:
            remaining = max(0, cutoff - time.time())
            try:
                candidate = await asyncio.wait_for(
                    self.request_queue.get(),
                    timeout=remaining
                )
            except asyncio.TimeoutError:
                break
            if not self.is_compatible_with_batch(candidate, collected):
                # Incompatible request goes back on the queue for a later batch.
                await self.request_queue.put(candidate)
                break
            collected.append(candidate)
        return collected

    def is_compatible_with_batch(self, request, batch):
        """Return True when ``request`` can join ``batch`` (sampling + memory)."""
        if not batch:
            return True
        reference = batch[0]
        same_sampling = (
            request['temperature'] == reference['temperature']
            and request['max_tokens'] == reference['max_tokens']
        )
        if not same_sampling:
            return False
        projected = self.estimate_batch_memory(batch + [request])
        return projected <= self.get_available_memory()

    async def process_batch(self, batch):
        """Decode the batch token-by-token until every request finishes."""
        inputs = self.prepare_batch_inputs(batch)
        # NOTE(review): assumes each request dict arrives with 'finished'
        # already initialized to False — confirm against the enqueue path.
        produced = {req['id']: [] for req in batch}
        while True:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    past_key_values=inputs.get('past_key_values'),
                    use_cache=True
                )
            next_tokens = self.sample_tokens(outputs.logits, batch)
            for pos, req in enumerate(batch):
                if req['finished']:
                    continue
                tok = next_tokens[pos]
                produced[req['id']].append(tok)
                hit_eos = tok == self.model.config.eos_token_id
                hit_limit = len(produced[req['id']]) >= req['max_tokens']
                if hit_eos or hit_limit:
                    req['finished'] = True
                    req['response'] = produced[req['id']]
            # Drop completed sequences; stop when none remain.
            batch = [req for req in batch if not req['finished']]
            if not batch:
                break
            inputs = self.update_batch_inputs(
                inputs,
                next_tokens,
                outputs.past_key_values
            )
Padding Optimization
class OptimizedPaddingStrategy:
    """Length-bucketing and padding utilities that minimize wasted pad tokens.

    Fixes over the original draft:
      * ``pad_token_id`` was read in ``left_padding_for_generation`` but never
        set anywhere — it is now a constructor parameter (default 0).
      * ``find_bucket``, ``pad_sequences`` and ``calculate_padding_efficiency``
        were referenced but never defined; they are implemented below.
    """

    def __init__(self, pad_token_id=0):
        # Token id written into padding positions. 0 is a common default;
        # pass your tokenizer's actual pad id in production.
        self.pad_token_id = pad_token_id
        # Accumulated padding statistics, keyed by strategy name.
        self.padding_stats = defaultdict(list)

    def dynamic_bucket_padding(self, sequences):
        """Group sequences into length buckets, then pad within each bucket.

        Returns a list of dicts: padded tensor, the bucket's target length,
        and the fraction of non-pad tokens in the padded batch.
        """
        buckets = self.compute_optimal_buckets(sequences)
        bucketed_sequences = defaultdict(list)
        for seq in sequences:
            bucket = self.find_bucket(len(seq), buckets)
            bucketed_sequences[bucket].append(seq)
        padded_batches = []
        for bucket_size, bucket_sequences in bucketed_sequences.items():
            padded = self.pad_sequences(bucket_sequences, bucket_size)
            padded_batches.append({
                'sequences': padded,
                'bucket_size': bucket_size,
                'efficiency': self.calculate_padding_efficiency(
                    bucket_sequences,
                    bucket_size
                )
            })
        return padded_batches

    def compute_optimal_buckets(self, sequences):
        """Compute bucket boundaries (each bucket is its members' max length)."""
        lengths = [len(seq) for seq in sequences]
        if len(set(lengths)) > 10:
            # Many distinct lengths: cluster them and take each cluster's max
            # so no sequence overflows its bucket.
            kmeans = KMeans(n_clusters=min(8, len(set(lengths))))
            kmeans.fit(np.array(lengths).reshape(-1, 1))
            buckets = []
            for label in range(kmeans.n_clusters):
                cluster_lengths = [l for i, l in enumerate(lengths)
                                   if kmeans.labels_[i] == label]
                if cluster_lengths:
                    buckets.append(max(cluster_lengths))
        else:
            buckets = sorted(set(lengths))
        return sorted(buckets)

    def find_bucket(self, length, buckets):
        """Return the smallest bucket that fits ``length`` (largest if none do)."""
        for boundary in buckets:
            if length <= boundary:
                return boundary
        return buckets[-1]

    def pad_sequences(self, sequences, target_length):
        """Right-pad each sequence with ``pad_token_id`` up to ``target_length``."""
        return torch.tensor([
            seq + [self.pad_token_id] * (target_length - len(seq))
            for seq in sequences
        ])

    def calculate_padding_efficiency(self, sequences, padded_length):
        """Fraction of slots holding real tokens (1.0 means zero waste)."""
        if not sequences or padded_length == 0:
            return 0.0
        total_real = sum(len(seq) for seq in sequences)
        return total_real / (len(sequences) * padded_length)

    def left_padding_for_generation(self, sequences):
        """Left-pad sequences (pads before tokens) for decoder-only generation."""
        max_length = max(len(seq) for seq in sequences)
        padded_sequences = []
        attention_masks = []
        for seq in sequences:
            padding_length = max_length - len(seq)
            padded = [self.pad_token_id] * padding_length + seq
            mask = [0] * padding_length + [1] * len(seq)
            padded_sequences.append(padded)
            attention_masks.append(mask)
        return {
            'input_ids': torch.tensor(padded_sequences),
            'attention_mask': torch.tensor(attention_masks),
            'padding_efficiency': self.calculate_padding_efficiency(
                sequences,
                max_length
            )
        }
Caching
KV Cache Management
class KVCacheManager:
    """LRU cache of per-layer, int8-compressed key/value projection states.

    Fixes over the original draft:
      * ``compress_kv_states`` divided by a zero scale for all-zero tensors
        (producing NaNs before the int8 cast); scales are now clamped to a
        tiny positive minimum.
      * ``evict_if_needed`` was called with ``compressed_kv.nbytes`` — a
        plain dict has no ``.nbytes``; the byte size is now derived from the
        int8 tensors, and the previously-missing ``compute_cache_key`` and
        ``evict_if_needed`` helpers are implemented (LRU by access time).
    """

    # Minimum quantization scale — avoids division by zero on all-zero states.
    _MIN_SCALE = 1e-8

    def __init__(self, cache_size_gb=10):
        self.cache_size_gb = cache_size_gb
        self.cache = {}
        self.access_times = {}   # cache_key -> last access timestamp (LRU order)
        self.cache_hits = 0
        self.cache_misses = 0
        self._cached_bytes = 0   # running byte total of entries in self.cache

    def compute_cache_key(self, input_ids, layer_idx):
        """Deterministic, hashable key from the token ids and layer index."""
        return (layer_idx, tuple(input_ids.flatten().tolist()))

    def _entry_nbytes(self, compressed_kv):
        """Approximate byte footprint of one entry (int8 tensors dominate)."""
        return compressed_kv['keys'].numel() + compressed_kv['values'].numel()

    def evict_if_needed(self, incoming_bytes):
        """Evict least-recently-used entries until ``incoming_bytes`` fits."""
        budget = self.cache_size_gb * (1024 ** 3)
        while self.cache and self._cached_bytes + incoming_bytes > budget:
            oldest_key = min(self.access_times, key=self.access_times.get)
            evicted = self.cache.pop(oldest_key)
            self.access_times.pop(oldest_key, None)
            self._cached_bytes -= self._entry_nbytes(evicted)

    def get_or_compute_kv_cache(self, input_ids, model, layer_idx):
        """Retrieve cached KV values for (input_ids, layer) or compute them."""
        cache_key = self.compute_cache_key(input_ids, layer_idx)
        if cache_key in self.cache:
            self.cache_hits += 1
            self.access_times[cache_key] = time.time()
            return self.cache[cache_key]
        self.cache_misses += 1
        with torch.no_grad():
            # get_hidden_states is provided elsewhere in the project —
            # TODO confirm it returns this layer's input activations.
            hidden_states = self.get_hidden_states(input_ids, model, layer_idx)
            layer = model.model.layers[layer_idx]
            key_states = layer.self_attn.k_proj(hidden_states)
            value_states = layer.self_attn.v_proj(hidden_states)
            compressed_kv = self.compress_kv_states(key_states, value_states)
        entry_bytes = self._entry_nbytes(compressed_kv)
        self.evict_if_needed(entry_bytes)
        self.cache[cache_key] = compressed_kv
        self._cached_bytes += entry_bytes
        self.access_times[cache_key] = time.time()
        return compressed_kv

    def compress_kv_states(self, key_states, value_states):
        """Symmetric int8 quantization of KV states (scale = absmax / 127)."""
        key_scale = (key_states.abs().max() / 127).clamp_min(self._MIN_SCALE)
        value_scale = (value_states.abs().max() / 127).clamp_min(self._MIN_SCALE)
        key_int8 = (key_states / key_scale).round().to(torch.int8)
        value_int8 = (value_states / value_scale).round().to(torch.int8)
        return {
            'keys': key_int8,
            'values': value_int8,
            'key_scale': key_scale,
            'value_scale': value_scale,
            'shape': key_states.shape,
            'dtype': key_states.dtype
        }
Prompt Cache
class PromptCacheSystem:
    """Caches prompt hidden states, with a semantic fallback for near-duplicates.

    Exact hits are looked up by SHA-256 of the prompt text; misses fall back
    to a FAISS nearest-neighbour search over prompt embeddings, reusing the
    cached state of a sufficiently similar prompt and computing only the
    incremental portion.
    """

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model
        self.prompt_cache = {}
        # L2 index over 768-d vectors — assumes the embedding model emits
        # 768-dimensional embeddings; TODO confirm.
        self.semantic_index = faiss.IndexFlatL2(768)
        self.prompt_to_index = {}  # prompt hash -> row id in semantic_index

    def get_cached_prompt_state(self, prompt, model):
        """Retrieve or compute hidden states for ``prompt``."""
        prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()
        if prompt_hash in self.prompt_cache:
            return self.prompt_cache[prompt_hash]
        similar_prompt = self.find_similar_prompt(prompt)
        if similar_prompt:
            # Reuse the similar prompt's state; compute only the delta.
            cached_state = self.prompt_cache[similar_prompt['hash']]
            additional_state = self.compute_incremental_state(
                similar_prompt['prompt'],
                prompt,
                cached_state,
                model
            )
            combined_state = self.combine_states(
                cached_state,
                additional_state
            )
            self.cache_prompt_state(prompt, combined_state)
            return combined_state
        prompt_state = self.compute_prompt_state(prompt, model)
        self.cache_prompt_state(prompt, prompt_state)
        return prompt_state

    def find_similar_prompt(self, prompt, threshold=0.95):
        """Find a semantically similar cached prompt, or None.

        Fixes vs. the original draft: FAISS pads results with index -1 when
        fewer than k vectors exist (now skipped instead of mis-mapped), and
        the reverse index->hash lookup no longer materializes two throwaway
        lists per candidate.
        """
        embedding = self.embedding_model.encode(prompt)
        if self.semantic_index.ntotal == 0:
            return None
        distances, indices = self.semantic_index.search(
            embedding.reshape(1, -1),
            k=5
        )
        for dist, idx in zip(distances[0], indices[0]):
            if idx < 0:
                # FAISS sentinel for "no neighbour in this result slot".
                continue
            # NOTE(review): this distance->similarity mapping assumes
            # normalized embeddings — confirm the embedding model normalizes.
            similarity = 1 - (dist / 2)
            if similarity > threshold:
                prompt_hash = next(
                    h for h, i in self.prompt_to_index.items() if i == idx
                )
                return {
                    'hash': prompt_hash,
                    'prompt': self.get_prompt_by_hash(prompt_hash),
                    'similarity': similarity
                }
        return None
Quantization
Mixed Precision Quantization
class MixedPrecisionQuantizer:
    """Assigns per-layer quantization bit-widths based on output sensitivity.

    Layers whose int8-quantized weights barely move the model's logits get
    4 bits; moderately sensitive layers get 8; the rest stay at 16.
    """

    def __init__(self, sensitivity_threshold=0.01):
        self.sensitivity_threshold = sensitivity_threshold
        self.layer_sensitivities = {}

    def analyze_layer_sensitivity(self, model, calibration_data):
        """Measure, per Linear layer, how much int8 quantization shifts outputs."""
        reference_logits = []
        with torch.no_grad():
            for sample in calibration_data:
                reference_logits.append(model(**sample).logits)
        for layer_name, layer in model.named_modules():
            if not isinstance(layer, nn.Linear):
                continue
            # Temporarily swap in quantized weights, re-run calibration,
            # then restore the originals.
            saved_weight = layer.weight.data.clone()
            layer.weight.data = self.quantize_tensor(
                layer.weight.data,
                bits=8
            )
            perturbed_logits = []
            with torch.no_grad():
                for sample in calibration_data:
                    perturbed_logits.append(model(**sample).logits)
            self.layer_sensitivities[layer_name] = self.calculate_sensitivity(
                reference_logits,
                perturbed_logits
            )
            layer.weight.data = saved_weight
        return self.layer_sensitivities

    def _select_bits(self, sensitivity):
        """Map a sensitivity score to a target bit-width (4, 8 or 16)."""
        if sensitivity < self.sensitivity_threshold:
            return 4
        if sensitivity < self.sensitivity_threshold * 2:
            return 8
        return 16

    def apply_mixed_precision_quantization(self, model):
        """Quantize each analyzed layer at its sensitivity-derived bit-width."""
        quantization_map = {
            layer_name: self._select_bits(score)
            for layer_name, score in self.layer_sensitivities.items()
        }
        for layer_name, layer in model.named_modules():
            target_bits = quantization_map.get(layer_name)
            # 16-bit layers are left untouched.
            if target_bits is not None and target_bits < 16:
                self.quantize_module(layer, target_bits)
        return quantization_map
GPTQ Quantization
class GPTQQuantizer:
    """Implements GPTQ (Generative Pre-trained Transformer Quantization).

    Fixes over the original draft:
      * ``quantize_weights_gptq`` returned only the LAST group's scale under
        ``'scales'``, making dequantization of every earlier group
        impossible; it now returns one scale per column group.
      * the bare ``except:`` around the Cholesky factorization now catches
        only ``RuntimeError`` (``torch.linalg.LinAlgError`` subclasses it).
      * the previously-missing ``find_optimal_scale`` helper is implemented
        (symmetric absmax scaling).
    """

    def __init__(self, bits=4, group_size=128, damp_percent=0.01):
        self.bits = bits                  # target bit-width per weight
        self.group_size = group_size      # columns sharing one scale
        self.damp_percent = damp_percent  # Hessian dampening fraction

    def quantize_model(self, model, calibration_loader):
        """Quantize every Linear layer of ``model`` using GPTQ."""
        model.eval()
        # collect_hessians / add_dampening / compute_pseudo_inverse /
        # replace_with_quantized are provided elsewhere — TODO confirm.
        hessians = self.collect_hessians(model, calibration_loader)
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                print(f"Quantizing {name}...")
                H = hessians[name]
                H_damp = self.add_dampening(H)
                try:
                    L = torch.linalg.cholesky(H_damp)
                    L_inv = torch.inverse(L)
                except RuntimeError:
                    # Singular / non-PSD Hessian: fall back to pseudo-inverse.
                    print(f"Warning: Cholesky failed for {name}, using SVD")
                    L_inv = self.compute_pseudo_inverse(H_damp)
                W_quant = self.quantize_weights_gptq(
                    module.weight.data,
                    L_inv
                )
                self.replace_with_quantized(module, W_quant)
        return model

    def find_optimal_scale(self, group_weights, bits):
        """Symmetric absmax scale mapping the group's max onto the int range."""
        qmax = 2 ** (bits - 1) - 1
        # clamp_min avoids a zero scale for an all-zero group.
        return group_weights.abs().max().clamp_min(1e-8) / qmax

    def quantize_weights_gptq(self, weights, L_inv):
        """GPTQ column-by-column quantization with error feedback.

        Returns int8 codes, one scale per column group (bug fix — the
        original kept only the last group's scale), and the group size.
        """
        W = weights.clone()
        Q = torch.zeros_like(W)
        group_scales = []
        for i in range(0, W.shape[1], self.group_size):
            group_end = min(i + self.group_size, W.shape[1])
            scale = self.find_optimal_scale(W[:, i:group_end], self.bits)
            group_scales.append(scale)
            for j in range(i, group_end):
                w = W[:, j]
                q = torch.round(w / scale).clamp(
                    -(2 ** (self.bits - 1)),
                    2 ** (self.bits - 1) - 1
                )
                Q[:, j] = q
                # Propagate this column's quantization error into the
                # not-yet-quantized columns (the GPTQ error-feedback step).
                if j < W.shape[1] - 1:
                    error = w - q * scale
                    W[:, j + 1:] -= error.unsqueeze(1) @ L_inv[j, j + 1:].unsqueeze(0)
        return {
            'quantized_weights': Q.to(torch.int8),
            'scales': torch.stack(group_scales),
            'group_size': self.group_size
        }
Results
(Interactive results diagram omitted — the original page renders it with JavaScript.)
- P50 Latency: 0.8 seconds (93% reduction)
- P95 Latency: 1.5 seconds (90% reduction)
- P99 Latency: 2.1 seconds (86% reduction)
- Throughput: 43.75x improvement
- Cost: 88% reduction
- Model Quality: <2% degradation in perplexity
Decision Rules
Optimize inference when:
- Cost per query threatens unit economics
- Latency impacts user experience
- Throughput limits product adoption
- GPU budget is constrained
Key principles:
- Profile first: you can’t optimize what you don’t measure
- Batching provides the biggest bang for the buck
- LLMs are memory-bound: optimizations reducing memory movement have outsized impact
- Not all requests need maximum quality: route to appropriate model variants
- Caching is complex but crucial—requires understanding access patterns