benchmark_transform/
main.rs

/// This tool reads a generated `weights.rs` file that includes child trie access and replaces
/// those DB reads and writes with the child-trie-specific DB weights defined in
/// `rocksdb_child_trie_weights.rs`.
///
/// NOTE: This tool relies on the fact that the keys used for child trie access are not registered in the
/// chain metadata, and hence show up in benchmarked weights as "UNKNOWN KEY". HOWEVER, storage declared using
/// `#[storage_alias]` is ALSO not registered with chain metadata, and hence reads of storage aliases
/// also show up in weights as "UNKNOWN KEY". Storage aliases are commonly used in migrations; therefore, use
/// caution when using benchmark weights to meter multi-block migrations.
9use regex::Regex;
10use std::{
11	env, fs,
12	io::{self},
13	process,
14};
15
16fn main() -> io::Result<()> {
17	// Get command line arguments
18	let args: Vec<String> = env::args().collect();
19
20	if args.len() != 3 {
21		eprintln!("Usage: {} <input_weights_file> <output_weights_file>", args[0]);
22		eprintln!("Example: {} weights.rs transformed_weights.rs", args[0]);
23		process::exit(1);
24	}
25
26	let input_file = &args[1];
27	let output_file = &args[2];
28
29	// Read the input file
30	let input = fs::read_to_string(input_file).map_err(|e| {
31		eprintln!("Error reading file '{input_file}': {e}");
32		e
33	})?;
34
35	// Check if already modified
36	if is_already_modified(&input) {
37		println!("✓ It was already modified and no need for transformation!");
38		return Ok(())
39	}
40
41	// Transform the content
42	let output = transform_weights(&input);
43
44	// Write to output file
45	fs::write(output_file, &output).map_err(|e| {
46		eprintln!("Error writing to file '{output_file}': {e}");
47		e
48	})?;
49
50	println!("✓ Transformation complete!");
51	println!("  Input:  {input_file}");
52	println!("  Output: {output_file}");
53	Ok(())
54}
55
/// Returns `true` when `content` contains the quoted tool name `'benchmark_transform'`,
/// which is part of the disclaimer this tool prepends to every file it transforms.
fn is_already_modified(content: &str) -> bool {
	const TOOL_MARKER: &str = "'benchmark_transform'";
	content.contains(TOOL_MARKER)
}
59
60fn transform_weights(content: &str) -> String {
61	let mut result = content.to_string();
62
63	// Step 1: Adds the modification disclaimer
64	result = ["\n//! MODIFIED by 'benchmark_transform' tool to replace child trie storage access with their specific DB weights", &result].join("\n//!");
65
66	// Step 2: Add RocksDbWeightChild import after the PhantomData import
67	let import_pattern = "use core::marker::PhantomData;";
68	let new_import = "use core::marker::PhantomData;\nuse common_runtime::weights::rocksdb_child_trie_weights::constants::RocksDbWeightChild;";
69	result = result.replace(import_pattern, new_import);
70
71	// Step 3: Process each function in SubstrateWeight<T> and the () implementation
72	result = process_weight_implementations(&result);
73
74	result
75}
76
77fn process_weight_implementations(content: &str) -> String {
78	let mut result = String::new();
79	let lines: Vec<&str> = content.lines().collect();
80	let mut i = 0;
81
82	while i < lines.len() {
83		let line = lines[i];
84		result.push_str(line);
85		result.push('\n');
86
87		// Check if we're at a function definition
88		if line.trim().starts_with("fn ") && line.contains("-> Weight {") {
89			// Collect function body lines and storage comments
90			let mut func_lines = vec![];
91			let mut storage_lines = Vec::new();
92
93			// Look backwards for Storage comments
94			let mut j = i;
95			while j > 0 {
96				j -= 1;
97				let prev_line = lines[j];
98				if prev_line.trim().starts_with("/// Storage:") {
99					storage_lines.insert(0, prev_line);
100				} else if !prev_line.trim().starts_with("///") {
101					break;
102				}
103			}
104
105			// Collect function body
106			let mut brace_count = 0;
107			let mut in_function = false;
108
109			while i < lines.len() {
110				let current = lines[i];
111				func_lines.push(current);
112
113				if current.contains('{') {
114					in_function = true;
115					brace_count += current.matches('{').count();
116				}
117				if current.contains('}') {
118					brace_count -= current.matches('}').count();
119				}
120
121				if in_function && brace_count == 0 {
122					break;
123				}
124				i += 1;
125			}
126
127			// Count UNKNOWN keys
128			let unknown_reads = count_unknown_keys(&storage_lines, "r:");
129			let unknown_writes = count_unknown_keys(&storage_lines, "w:");
130
131			if unknown_reads > 0 || unknown_writes > 0 {
132				// Process the function body
133				let processed =
134					process_function_body(&func_lines.join("\n"), unknown_reads, unknown_writes);
135				// Remove the original function line since we already added it
136				let processed_lines: Vec<&str> = processed.lines().collect();
137				// Skip first line as it's already added
138				for pline in &processed_lines[1..] {
139					result.push_str(pline);
140					result.push('\n');
141				}
142				i += 1;
143				continue;
144			}
145		}
146
147		i += 1;
148	}
149
150	result
151}
152
/// Sums the access counts that `storage_lines` attribute to "UNKNOWN KEY" entries for one
/// kind of operation.
///
/// `operation` is the prefix that precedes the count, e.g. `"r:"` or `"w:"`. For each line
/// containing "UNKNOWN KEY", the text after the first occurrence of `operation` up to the
/// next whitespace character or `)` is parsed as a `u64`. Lines where the prefix is missing,
/// where no such delimiter follows, or where the text is not a number contribute nothing.
fn count_unknown_keys(storage_lines: &[&str], operation: &str) -> u64 {
	storage_lines
		.iter()
		.filter(|entry| entry.contains("UNKNOWN KEY"))
		.filter_map(|entry| {
			// e.g. "... (r:1 w:1)" with operation "r:" leaves the tail "1 w:1)".
			let (_, tail) = entry.split_once(operation)?;
			let end = tail.find(|c: char| c.is_whitespace() || c == ')')?;
			tail[..end].parse::<u64>().ok()
		})
		.sum()
}
170
171fn process_function_body(func_body: &str, unknown_reads: u64, unknown_writes: u64) -> String {
172	let lines: Vec<&str> = func_body.lines().collect();
173	let mut result = Vec::new();
174	let mut i = 0;
175	let read_regex = Regex::new(r"reads\((\d+)_u64\)").expect("Should create regex.");
176	let write_regex = Regex::new(r"writes\((\d+)_u64\)").expect("Should create regex.");
177
178	while i < lines.len() {
179		let line = lines[i];
180
181		// Check for T::DbWeight::get().reads() or RocksDbWeight::get().reads()
182		if (line.contains("T::DbWeight::get().reads(") ||
183			line.contains("RocksDbWeight::get().reads(")) &&
184			unknown_reads > 0
185		{
186			// Extract the current read count
187			if let Some(caps) = read_regex.captures(line) {
188				if let Ok(current_reads) = caps[1].parse::<u64>() {
189					// Only subtract if current_reads is greater than unknown_reads
190					if current_reads >= unknown_reads {
191						let new_reads = current_reads - unknown_reads;
192						let indent = get_indent(line);
193
194						// Determine which DbWeight to use
195						let db_weight = if line.contains("T::DbWeight") {
196							"T::DbWeight"
197						} else {
198							"RocksDbWeight"
199						};
200
201						// Add the modified line
202						if new_reads > 0 {
203							result.push(format!(
204								"{indent}.saturating_add({db_weight}::get().reads({new_reads}_u64))"
205							));
206						}
207						// Add the child trie reads line
208						result.push(format!(
209							"{indent}.saturating_add(RocksDbWeightChild::get().reads({unknown_reads}_u64))"
210						));
211						i += 1;
212						continue;
213					}
214				}
215			}
216		}
217
218		// Check for T::DbWeight::get().writes() or RocksDbWeight::get().writes()
219		if (line.contains("T::DbWeight::get().writes(") ||
220			line.contains("RocksDbWeight::get().writes(")) &&
221			unknown_writes > 0
222		{
223			if let Some(caps) = write_regex.captures(line) {
224				if let Ok(current_writes) = caps[1].parse::<u64>() {
225					// Only subtract if current_writes is greater than unknown_writes
226					if current_writes >= unknown_writes {
227						let new_writes = current_writes - unknown_writes;
228						let indent = get_indent(line);
229
230						let db_weight = if line.contains("T::DbWeight") {
231							"T::DbWeight"
232						} else {
233							"RocksDbWeight"
234						};
235
236						if new_writes > 0 {
237							result.push(format!(
238								"{indent}.saturating_add({db_weight}::get().writes({new_writes}_u64))"
239							));
240						}
241						result.push(format!(
242							"{indent}.saturating_add(RocksDbWeightChild::get().writes({unknown_writes}_u64))"
243						));
244						i += 1;
245						continue;
246					}
247				}
248			}
249		}
250
251		result.push(line.to_string());
252		i += 1;
253	}
254
255	result.join("\n")
256}
257
/// Returns the leading whitespace of `line` as an owned string (empty when the line starts
/// with a non-whitespace character or is empty).
fn get_indent(line: &str) -> String {
	line.chars().take_while(|ch| ch.is_whitespace()).collect()
}